Update app.py
Browse files
app.py
CHANGED
|
@@ -74,6 +74,10 @@ class Prober:
|
|
| 74 |
|
| 75 |
self.model = self.model.eval()
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
@torch.no_grad()
|
| 79 |
def probe(self, idx, re, search_by_sample_id: bool= True):
|
|
@@ -82,38 +86,40 @@ class Prober:
|
|
| 82 |
else:
|
| 83 |
img_path, target = self.df[self.df.image_id == idx][['file_path','bbox']].values[0]
|
| 84 |
img = Image.open(self.zipfile.open(img_path)).convert('RGB')
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
|
| 110 |
prober = Prober(
|
| 111 |
df_path = 'data/val-sim_metric.json',
|
| 112 |
dataset_path = "data/saiapr_tc-12.zip",
|
| 113 |
model_checkpoint= "cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt"
|
| 114 |
)
|
| 115 |
|
| 116 |
-
demo = gr.Interface(fn=prober.probe, inputs=["number", "text"
|
| 117 |
|
| 118 |
demo.queue(concurrency_count=10)
|
| 119 |
demo.launch(debug=True)
|
|
|
|
| 74 |
|
| 75 |
self.model = self.model.eval()
|
| 76 |
|
| 77 |
+
def preview_image(self, idx):
|
| 78 |
+
img_path, target, = self.df.loc[idx][['file_path','bbox']].values
|
| 79 |
+
img = Image.open(self.zipfile.open(img_path)).convert('RGB')
|
| 80 |
+
return img
|
| 81 |
|
| 82 |
@torch.no_grad()
|
| 83 |
def probe(self, idx, re, search_by_sample_id: bool= True):
|
|
|
|
| 86 |
else:
|
| 87 |
img_path, target = self.df[self.df.image_id == idx][['file_path','bbox']].values[0]
|
| 88 |
img = Image.open(self.zipfile.open(img_path)).convert('RGB')
|
| 89 |
+
if re != "":
|
| 90 |
+
W0, H0 = img.size
|
| 91 |
+
sample = {
|
| 92 |
+
'image': img,
|
| 93 |
+
'image_size': (H0, W0), # image original size
|
| 94 |
+
'bbox': torch.tensor([copy(target)]),
|
| 95 |
+
'bbox_raw': torch.tensor([copy(target)]),
|
| 96 |
+
'mask': torch.ones((1, H0, W0), dtype=torch.float32), # visibiity mask
|
| 97 |
+
'mask_bbox': None, # target bbox mask
|
| 98 |
+
}
|
| 99 |
+
sample = self.transform(sample)
|
| 100 |
+
tok = self.tokenizer(re,
|
| 101 |
+
max_length=30,
|
| 102 |
+
return_tensors='pt',
|
| 103 |
+
truncation=True)
|
| 104 |
+
inn = {'image': torch.stack([sample['image']]),
|
| 105 |
+
'mask': torch.stack([sample['mask']]),
|
| 106 |
+
'tok': tok}
|
| 107 |
+
output = undo_box_transforms_batch(self.model(inn)[0],
|
| 108 |
+
[sample['tr_param']]).numpy().tolist()[0]
|
| 109 |
+
img1 = ImageDraw.Draw(img)
|
| 110 |
+
#img1.rectangle(target, outline ="#0000FF00", width=3)
|
| 111 |
+
img1.rectangle(output, outline ="#00FF0000", width=3)
|
| 112 |
+
return img
|
| 113 |
+
else:
|
| 114 |
+
return img
|
| 115 |
+
|
| 116 |
prober = Prober(
|
| 117 |
df_path = 'data/val-sim_metric.json',
|
| 118 |
dataset_path = "data/saiapr_tc-12.zip",
|
| 119 |
model_checkpoint= "cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt"
|
| 120 |
)
|
| 121 |
|
| 122 |
+
demo = gr.Interface(fn=prober.probe, inputs=["number", "text"], outputs="image", live=True)
|
| 123 |
|
| 124 |
demo.queue(concurrency_count=10)
|
| 125 |
demo.launch(debug=True)
|