Spaces:
Runtime error
Runtime error
Fix labels
Browse files- app_caption.py +2 -2
- prismer_model.py +32 -14
app_caption.py
CHANGED
|
@@ -11,11 +11,11 @@ from prismer_model import Model
|
|
| 11 |
|
| 12 |
def create_demo():
|
| 13 |
model = Model()
|
| 14 |
-
|
| 15 |
with gr.Row():
|
| 16 |
with gr.Column():
|
| 17 |
-
model_name = gr.Dropdown(label='Model', choices=['Prismer-Base'], value='Prismer-Base')
|
| 18 |
image = gr.Image(label='Input', type='filepath')
|
|
|
|
| 19 |
run_button = gr.Button('Run')
|
| 20 |
with gr.Column(scale=1.5):
|
| 21 |
caption = gr.Text(label='Caption')
|
|
|
|
| 11 |
|
| 12 |
def create_demo():
|
| 13 |
model = Model()
|
| 14 |
+
model.mode = 'caption'
|
| 15 |
with gr.Row():
|
| 16 |
with gr.Column():
|
|
|
|
| 17 |
image = gr.Image(label='Input', type='filepath')
|
| 18 |
+
model_name = gr.Dropdown(label='Model', choices=['Prismer-Base', 'Prismer-Large'], value='Prismer-Base')
|
| 19 |
run_button = gr.Button('Run')
|
| 20 |
with gr.Column(scale=1.5):
|
| 21 |
caption = gr.Text(label='Caption')
|
prismer_model.py
CHANGED
|
@@ -58,7 +58,7 @@ def run_experts(image_path: str) -> tuple[str | None, ...]:
|
|
| 58 |
|
| 59 |
keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
|
| 60 |
results = [pathlib.Path('prismer/helpers/labels') / key / 'helpers/images/image.png' for key in keys]
|
| 61 |
-
return tuple(path.as_posix()
|
| 62 |
|
| 63 |
|
| 64 |
class Model:
|
|
@@ -67,24 +67,42 @@ class Model:
|
|
| 67 |
self.model = None
|
| 68 |
self.tokenizer = None
|
| 69 |
self.exp_name = ''
|
|
|
|
| 70 |
|
| 71 |
def set_model(self, exp_name: str) -> None:
|
| 72 |
if exp_name == self.exp_name:
|
| 73 |
return
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
model.load_state_dict(state_dict)
|
| 89 |
model.eval()
|
| 90 |
|
|
|
|
| 58 |
|
| 59 |
keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
|
| 60 |
results = [pathlib.Path('prismer/helpers/labels') / key / 'helpers/images/image.png' for key in keys]
|
| 61 |
+
return tuple(path.as_posix() for path in results)
|
| 62 |
|
| 63 |
|
| 64 |
class Model:
|
|
|
|
| 67 |
self.model = None
|
| 68 |
self.tokenizer = None
|
| 69 |
self.exp_name = ''
|
| 70 |
+
self.mode = ''
|
| 71 |
|
| 72 |
def set_model(self, exp_name: str) -> None:
|
| 73 |
if exp_name == self.exp_name:
|
| 74 |
return
|
| 75 |
|
| 76 |
+
if self.mode == 'caption':
|
| 77 |
+
config = {
|
| 78 |
+
'dataset': 'demo',
|
| 79 |
+
'data_path': 'prismer/helpers',
|
| 80 |
+
'label_path': 'prismer/helpers/labels',
|
| 81 |
+
'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
|
| 82 |
+
'image_resolution': 480,
|
| 83 |
+
'prismer_model': 'prismer_base' if exp_name == 'Prismer-Base' else 'prismer_large',
|
| 84 |
+
'freeze': 'freeze_vision',
|
| 85 |
+
'prefix': 'A picture of',
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
model = PrismerCaption(config)
|
| 89 |
+
state_dict = torch.load(f'prismer/logging/caption_{exp_name}/pytorch_model.bin', map_location='cuda:0')
|
| 90 |
+
|
| 91 |
+
elif self.mode == 'vqa':
|
| 92 |
+
config = {
|
| 93 |
+
'dataset': 'demo',
|
| 94 |
+
'data_path': 'prismer/helpers',
|
| 95 |
+
'label_path': 'prismer/helpers/labels',
|
| 96 |
+
'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
|
| 97 |
+
'image_resolution': 480,
|
| 98 |
+
'prismer_model': 'prismer_base' if exp_name == 'Prismer-Base' else 'prismer_large',
|
| 99 |
+
'freeze': 'freeze_vision',
|
| 100 |
+
'prefix': 'A picture of',
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
model = PrismerCaption(config)
|
| 104 |
+
state_dict = torch.load(f'prismer/logging/caption_{exp_name}/pytorch_model.bin', map_location='cuda:0')
|
| 105 |
+
|
| 106 |
model.load_state_dict(state_dict)
|
| 107 |
model.eval()
|
| 108 |
|