# app.py — Hugging Face Space: Gradio demo for microsoft/Florence-2-base
import gradio as gr
from PIL import Image
import torch
# ── Patch transformers BEFORE importing AutoModel ──────────────────────────
# NOTE(review): Florence-2 ships custom (trust_remote_code) config classes;
# this wrapper presumably works around an AttributeError raised when some
# transformers code path reads `forced_bos_token_id` before the config has
# set it — TODO confirm against the transformers version pinned for this
# Space. The patch must be installed before AutoModel* is imported so every
# config instance created during model loading already carries the attribute.
from transformers import PretrainedConfig

_original_pretrained_init = PretrainedConfig.__init__


def _patched_pretrained_init(self, *args, **kwargs):
    """Ensure `forced_bos_token_id` exists, then run the real __init__.

    The attribute is set *before* delegating, so anything the original
    initializer triggers can already see it; the original __init__ may
    later overwrite it from kwargs, which is fine.
    """
    if not hasattr(self, "forced_bos_token_id"):
        self.forced_bos_token_id = kwargs.get("forced_bos_token_id", None)
    _original_pretrained_init(self, *args, **kwargs)


PretrainedConfig.__init__ = _patched_pretrained_init
# ───────────────────────────────────────────────────────────────────────────
from transformers import AutoProcessor, AutoModelForCausalLM

# Prefer GPU + fp16 when CUDA is available; otherwise CPU + fp32
# (fp16 on CPU is slow/unsupported for many ops).
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

print(f"Loading Florence-2 on {device}...")
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base",
    torch_dtype=torch_dtype,
    trust_remote_code=True,  # Florence-2 uses custom modeling code from the Hub
    low_cpu_mem_usage=False  # ← fix: prevents meta tensor initialization
).to(device)
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base",
    trust_remote_code=True
)
print("Model loaded successfully!")
# Human-readable task name (shown in the UI dropdown) → Florence-2 task
# token that must prefix the prompt fed to the processor.
TASK_PROMPTS = {
    "Caption": "<CAPTION>",
    "Detailed Caption": "<DETAILED_CAPTION>",
    "More Detailed Caption": "<MORE_DETAILED_CAPTION>",
    "Object Detection": "<OD>",
    "Dense Region Caption": "<DENSE_REGION_CAPTION>",
    "Region Proposal": "<REGION_PROPOSAL>",
    "Caption to Phrase Grounding": "<CAPTION_TO_PHRASE_GROUNDING>",
    "Referring Expression Segmentation": "<REFERRING_EXPRESSION_SEGMENTATION>",
    "OCR": "<OCR>",
    "OCR with Region": "<OCR_WITH_REGION>",
}

# Tasks whose task token must be followed by free-form user text
# (the phrase to ground / segment).
TASKS_REQUIRING_TEXT = {"Caption to Phrase Grounding", "Referring Expression Segmentation"}
def run_florence(image: Image.Image, task: str, text_input: str = ""):
    """Run one Florence-2 task on *image* and return the parsed result as text.

    Parameters
    ----------
    image : PIL.Image.Image or None
        Image from the Gradio widget; None when nothing was uploaded.
    task : str
        Human-readable task name; expected to be a key of TASK_PROMPTS.
    text_input : str or None, optional
        Extra text required by grounding / segmentation tasks. Gradio may
        deliver None (e.g. when the textbox is hidden).

    Returns
    -------
    str
        Stringified post-processed model output, or a ⚠️/❌ message on bad
        input / inference failure.
    """
    if image is None:
        return "⚠️ Please upload an image."
    # Guard the lookup: an unknown task would otherwise raise an unhandled
    # KeyError (the lookup sits outside the try block below).
    if task not in TASK_PROMPTS:
        return f"⚠️ Unknown task '{task}'."
    task_prompt = TASK_PROMPTS[task]

    # Normalize before .strip(): Gradio can pass None for a hidden textbox,
    # and None.strip() would raise AttributeError.
    text_input = (text_input or "").strip()
    if task in TASKS_REQUIRING_TEXT:
        if not text_input:
            return f"⚠️ Task '{task}' requires a text input. Please provide one."
        prompt = task_prompt + text_input
    else:
        prompt = task_prompt

    try:
        # BatchFeature.to(device, dtype) casts only floating-point tensors,
        # so input_ids stay integer while pixel_values match the model dtype.
        inputs = processor(
            text=prompt,
            images=image,
            return_tensors="pt"
        ).to(device, torch_dtype)
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                num_beams=3,
            )
        # Keep special tokens: the task token in the decoded text is needed
        # by post_process_generation to pick the right parser.
        generated_text = processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )[0]
        parsed = processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height)
        )
        return str(parsed)
    except Exception as e:
        # Surface inference errors in the UI instead of crashing the app.
        return f"❌ Error during inference: {str(e)}"
def toggle_text_input(task):
    """Show the free-text box only for tasks that require extra text."""
    needs_text = task in TASKS_REQUIRING_TEXT
    return gr.update(visible=needs_text)
# ── Gradio UI ──────────────────────────────────────────────────────────────
with gr.Blocks(title="Florence-2 Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🖼️ Microsoft Florence-2-base
        Multi-task vision model: captioning, OCR, object detection, segmentation, and more.
        Upload an image, choose a task, and hit **Run**.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Input Image")
            task_dropdown = gr.Dropdown(
                choices=list(TASK_PROMPTS.keys()),
                value="Detailed Caption",
                label="Task"
            )
            # Hidden by default; toggle_text_input reveals it for tasks
            # in TASKS_REQUIRING_TEXT.
            text_input = gr.Textbox(
                label="Text Input (required for grounding / segmentation tasks)",
                placeholder="e.g. 'a cat on the sofa'",
                visible=False
            )
            run_btn = gr.Button("▶ Run", variant="primary")
        with gr.Column(scale=1):
            output = gr.Textbox(label="Output", lines=20)

    # Show/hide the text box whenever the selected task changes.
    task_dropdown.change(
        fn=toggle_text_input,
        inputs=task_dropdown,
        outputs=text_input
    )
    run_btn.click(
        fn=run_florence,
        inputs=[image_input, task_dropdown, text_input],
        outputs=output
    )
    gr.Examples(
        examples=[
            ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg", "Detailed Caption", ""],
            ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg", "More Detailed Caption", ""],
            ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg", "OCR", ""],
            ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg", "Object Detection", ""],
            ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg", "Dense Region Caption", ""],
        ],
        inputs=[image_input, task_dropdown, text_input],
    )

demo.launch()