Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import PIL.Image | |
| import transformers | |
| from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor | |
| import torch | |
| import os | |
| import string | |
| import functools | |
| import re | |
| import numpy as np | |
| import spaces | |
# Selectable checkpoints: dropdown label -> Hugging Face Hub repository id.
MODEL_IDS = {
    "paligemma-3b-ft-widgetcap-waveui-448": "agentsea/paligemma-3b-ft-widgetcap-waveui-448",
    "paligemma-3b-ft-waveui-896": "agentsea/paligemma-3b-ft-waveui-896",
}

# Hex colour palette.
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']

# Prefer CUDA whenever PyTorch can see a GPU; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load every checkpoint once, at import time: first all models (switched to
# eval mode and moved to `device`), then all matching processors.
models = {
    label: PaliGemmaForConditionalGeneration.from_pretrained(repo_id).eval().to(device)
    for label, repo_id in MODEL_IDS.items()
}
processors = {
    label: PaliGemmaProcessor.from_pretrained(repo_id)
    for label, repo_id in MODEL_IDS.items()
}
###### Transformers Inference
@spaces.GPU  # required on ZeroGPU Spaces; a no-op on ordinary hardware
def infer(
    image: PIL.Image.Image,
    text: str,
    max_new_tokens: int,
    model_choice: str
) -> str:
    """Run greedy PaliGemma generation for ``text`` conditioned on ``image``.

    Args:
        image: input picture passed to the processor together with ``text``.
        text: prompt, e.g. ``"detect profile picture"``.
        max_new_tokens: upper bound on the number of generated tokens.
        model_choice: key into the module-level ``models``/``processors`` dicts.

    Returns:
        Only the newly generated text (special tokens skipped, leading
        newlines stripped), without the echoed prompt.

    Raises:
        KeyError: if ``model_choice`` is not a loaded model.
    """
    model = models[model_choice]
    processor = processors[model_choice]
    inputs = processor(text=text, images=image, return_tensors="pt").to(device)
    # generate() returns the prompt tokens followed by the new ones; remember
    # the prompt length so the answer can be cut out at the token level.
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
    # Decode only the new tokens. Slicing the decoded string by len(text) is
    # fragile: it assumes the decoded prompt round-trips to exactly `text`.
    result = processor.batch_decode(generated_ids[:, prompt_len:], skip_special_tokens=True)
    return result[0].lstrip("\n")
def parse_segmentation(input_image, input_text, model_choice):
    """Run detection on ``input_image`` and format it for ``gr.AnnotatedImage``.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or None
            when nothing was uploaded.
        input_text: detection prompt, e.g. ``"detect sign in button"``.
        model_choice: dropdown selection; must be a key of ``models``.

    Returns:
        ``(input_image, annotations)`` where each annotation is
        ``(mask_or_xyxy_box, label)``; labels are made unique by
        ``extract_objs(..., unique_labels=True)``.

    Raises:
        gr.Error: if no image was uploaded or no valid model is selected, so
            the UI shows a readable message instead of a stack trace.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")
    if model_choice not in models:
        raise gr.Error("Please select a model.")

    out = infer(input_image, input_text, max_new_tokens=100, model_choice=model_choice)
    width, height = input_image.size
    objs = extract_objs(out, width, height, unique_labels=True)

    # Keep only entries that carry geometry; plain-text fragments are dropped.
    annotations = [
        (
            obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
            obj['name'] or '',
        )
        for obj in objs
        if 'mask' in obj or 'xyxy' in obj
    ]
    return input_image, annotations
######## Demo
# Blank lines matter here: without them Markdown folds "Note:" and "Usage:"
# into the preceding bullet item as lazy continuation lines.
INTRO_TEXT = """## PaliGemma WaveUI

Two models fine-tuned on the [WaveUI dataset](https://huggingface.co/datasets/agentsea/wave-ui), each from a different base:

- [paligemma-3b-ft-widgetcap-waveui-448](https://huggingface.co/agentsea/paligemma-3b-ft-widgetcap-waveui-448)
- [paligemma-3b-ft-waveui-896](https://huggingface.co/agentsea/paligemma-3b-ft-waveui-896)

Note:

- The task they were fine-tuned on was detection, so they may not generalize to other tasks.

Usage: write the task keyword "detect" before the element you want the model to detect. For example, "detect profile picture".
"""

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(INTRO_TEXT)
    with gr.Tab("Detection"):
        # Pre-select the first checkpoint so the click handler never receives
        # None as the model key.
        model_choice = gr.Dropdown(
            label="Select Model",
            choices=list(MODEL_IDS.keys()),
            value=next(iter(MODEL_IDS)),
        )
        image = gr.Image(type="pil")
        seg_input = gr.Text(label="Detect instruction (e.g. 'detect sign in button')")
        seg_btn = gr.Button("Submit")
        annotated_image = gr.AnnotatedImage(label="Output")

        examples = [["./airbnb.jpg", "detect 'Amazing pools' button"]]
        gr.Examples(
            examples=examples,
            inputs=[image, seg_input],
        )

        seg_inputs = [image, seg_input, model_choice]
        seg_outputs = [annotated_image]
        seg_btn.click(
            fn=parse_segmentation,
            inputs=seg_inputs,
            outputs=seg_outputs,
        )
# One detection: optional leading free text (group 1), four <locNNNN> tokens
# (groups 2-5), an optional run of sixteen <segNNN> tokens (groups 6-21), an
# optional label (group 22) and an optional "; " separator.
_SEGMENT_DETECT_RE = re.compile(
    r'(.*?)' +
    r'<loc(\d{4})>' * 4 + r'\s*' +
    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
    r'\s*([^;<>]+)? ?(?:; )?',
)


def extract_objs(text, width, height, unique_labels=False):
    """Return objs for a string with "<loc>" and "<seg>" tokens.

    Args:
        text: model output to parse.
        width: image width in pixels, used to scale x coordinates.
        height: image height in pixels, used to scale y coordinates.
        unique_labels: when True, repeated labels get "'" appended until
            they are unique.

    Returns:
        A list of dicts in input order. Plain-text fragments are
        ``{'content': str}``; detections are ``{'content', 'xyxy', 'mask',
        'name'}`` where ``xyxy`` is ``(x1, y1, x2, y2)`` in pixels, ``mask``
        is always None (seg tokens are matched but not decoded) and ``name``
        is the stripped label or None. Concatenating every ``content`` gives
        back ``text``.
    """
    objs = []
    seen = set()
    while text:
        m = _SEGMENT_DETECT_RE.match(text)
        if not m:
            break
        gs = list(m.groups())
        before = gs.pop(0)
        name = gs.pop()
        # The label group is greedy and keeps the space before a "; "
        # separator ("button " vs a final "button"); strip it so identical
        # labels compare equal for the uniqueness check below.
        if name is not None:
            name = name.strip() or None
        # Loc values are read in (y1, x1, y2, x2) order on a 0..1024 scale.
        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
        y1, x1, y2, x2 = map(round, (y1 * height, x1 * width, y2 * height, x2 * width))
        content = m.group()
        if before:
            objs.append(dict(content=before))
            content = content[len(before):]
        while unique_labels and name in seen:
            name = (name or '') + "'"
        seen.add(name)
        objs.append(dict(
            content=content, xyxy=(x1, y1, x2, y2), mask=None, name=name))
        text = text[len(before) + len(content):]
    if text:
        objs.append(dict(content=text))
    return objs
if __name__ == "__main__":
    # Cap the pending-request queue at 10 entries, then start the server with
    # debug output enabled.
    app = demo.queue(max_size=10)
    app.launch(debug=True)