| import gradio as gr |
| from PIL import Image |
| import torch |
|
|
| |
| from transformers import PretrainedConfig |
|
|
# Keep a reference to the stock initializer so the wrapper can delegate to it.
_original_pretrained_init = PretrainedConfig.__init__


def _patched_pretrained_init(self, *args, **kwargs):
    """Initializer wrapper that guarantees ``forced_bos_token_id`` exists.

    NOTE(review): presumably a compatibility shim for Florence-2's remote
    config code, which appears to read this attribute before the base
    initializer would set it — confirm against the Hub repo's config class.
    """
    if not hasattr(self, "forced_bos_token_id"):
        self.forced_bos_token_id = kwargs.get("forced_bos_token_id")
    _original_pretrained_init(self, *args, **kwargs)


# Install the wrapper globally, before any config class is instantiated.
PretrainedConfig.__init__ = _patched_pretrained_init
| |
|
|
| from transformers import AutoProcessor, AutoModelForCausalLM |
|
|
# Decide compute placement once: half precision is only used on GPU,
# since float16 inference on CPU is slow/unsupported for many ops.
_cuda_available = torch.cuda.is_available()
device = "cuda" if _cuda_available else "cpu"
torch_dtype = torch.float16 if _cuda_available else torch.float32


print(f"Loading Florence-2 on {device}...")
|
|
# Load Florence-2 weights. trust_remote_code is required because the model
# and processor classes live in the Hub repository, not in transformers.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base",
    torch_dtype=torch_dtype,
    trust_remote_code=True,
    # NOTE(review): low_cpu_mem_usage disabled deliberately — presumably to
    # avoid a meta-device/accelerate issue with the remote code; confirm.
    low_cpu_mem_usage=False
).to(device)


# The processor bundles the image preprocessor, tokenizer, and the
# task-specific post-processing used by run_florence below.
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base",
    trust_remote_code=True
)


print("Model loaded successfully!")
|
|
# Maps the human-readable dropdown labels shown in the UI to Florence-2's
# special task tokens. The token is both the generation prompt prefix and
# the `task` argument expected by processor.post_process_generation.
TASK_PROMPTS = {
    "Caption": "<CAPTION>",
    "Detailed Caption": "<DETAILED_CAPTION>",
    "More Detailed Caption": "<MORE_DETAILED_CAPTION>",
    "Object Detection": "<OD>",
    "Dense Region Caption": "<DENSE_REGION_CAPTION>",
    "Region Proposal": "<REGION_PROPOSAL>",
    "Caption to Phrase Grounding": "<CAPTION_TO_PHRASE_GROUNDING>",
    "Referring Expression Segmentation": "<REFERRING_EXPRESSION_SEGMENTATION>",
    "OCR": "<OCR>",
    "OCR with Region": "<OCR_WITH_REGION>",
}


# Tasks whose task token must be suffixed with user-provided text
# (e.g. "<CAPTION_TO_PHRASE_GROUNDING>a cat on the sofa").
TASKS_REQUIRING_TEXT = {"Caption to Phrase Grounding", "Referring Expression Segmentation"}
|
|
|
|
def run_florence(image: Image.Image, task: str, text_input: str = ""):
    """Run one Florence-2 task on *image* and return the parsed result.

    Args:
        image: Input PIL image (any mode; converted to RGB before processing).
        task: Human-readable task name — must be a key of TASK_PROMPTS.
        text_input: Extra text required by grounding/segmentation tasks.

    Returns:
        str: The post-processed model output, or a user-facing warning/error
        message (the UI displays the return value verbatim).
    """
    if image is None:
        return "β οΈ Please upload an image."

    # Unknown task names should produce a UI message like every other
    # failure path, not an uncaught KeyError.
    task_prompt = TASK_PROMPTS.get(task)
    if task_prompt is None:
        return f"β οΈ Unknown task '{task}'."

    # Gradio can hand over None for a hidden/cleared textbox; normalize first
    # so .strip() never explodes.
    text_input = (text_input or "").strip()

    if task in TASKS_REQUIRING_TEXT:
        if not text_input:
            return f"β οΈ Task '{task}' requires a text input. Please provide one."
        prompt = task_prompt + text_input
    else:
        prompt = task_prompt

    try:
        # The vision tower expects 3-channel input; uploaded PNGs are often
        # RGBA or palette mode. Converting here also keeps the image_size
        # passed to post-processing consistent with what the model saw.
        if image.mode != "RGB":
            image = image.convert("RGB")

        inputs = processor(
            text=prompt,
            images=image,
            return_tensors="pt"
        ).to(device, torch_dtype)  # dtype cast only touches float tensors

        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                num_beams=3,
            )

        # Keep special tokens: post_process_generation parses the task tags
        # and location tokens out of the raw decoded string.
        generated_text = processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )[0]

        parsed = processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height)
        )

        return str(parsed)

    except Exception as e:
        # Broad catch is deliberate: this is the UI boundary, and any
        # inference failure should surface as text rather than crash the app.
        return f"β Error during inference: {str(e)}"
|
|
|
|
def toggle_text_input(task):
    """Return a Gradio update showing the text box only for tasks that need text."""
    needs_text = task in TASKS_REQUIRING_TEXT
    return gr.update(visible=needs_text)
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: image + task controls in the left column, text output on the
# right. Statement order inside the Blocks context defines the layout.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Florence-2 Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # πΌοΈ Microsoft Florence-2-base
        Multi-task vision model: captioning, OCR, object detection, segmentation, and more.
        Upload an image, choose a task, and hit **Run**.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Input Image")

            task_dropdown = gr.Dropdown(
                choices=list(TASK_PROMPTS.keys()),
                value="Detailed Caption",
                label="Task"
            )

            # Hidden by default; toggle_text_input reveals it for
            # grounding/segmentation tasks.
            text_input = gr.Textbox(
                label="Text Input (required for grounding / segmentation tasks)",
                placeholder="e.g. 'a cat on the sofa'",
                visible=False
            )

            run_btn = gr.Button("βΆ Run", variant="primary")

        with gr.Column(scale=1):
            output = gr.Textbox(label="Output", lines=20)

    # Show/hide the free-text box whenever the task selection changes.
    task_dropdown.change(
        fn=toggle_text_input,
        inputs=task_dropdown,
        outputs=text_input
    )

    run_btn.click(
        fn=run_florence,
        inputs=[image_input, task_dropdown, text_input],
        outputs=output
    )

    # Clicking an example row fills all three inputs (image, task, text).
    gr.Examples(
        examples=[
            ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg", "Detailed Caption", ""],
            ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg", "More Detailed Caption", ""],
            ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg", "OCR", ""],
            ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg", "Object Detection", ""],
            ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg", "Dense Region Caption", ""],
        ],
        inputs=[image_input, task_dropdown, text_input],
    )


demo.launch()