James040 committed on
Commit
234780c
·
verified ·
1 Parent(s): 0f0e73d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from transformers import AutoProcessor, AutoModelForCausalLM
5
+
6
# All models and input tensors are placed on the CPU, as requested.
device = "cpu"
8
+
9
def load_vlm(model_name):
    """Load a model and its processor from the ``microsoft`` hub namespace.

    Args:
        model_name: Repository name without the ``microsoft/`` prefix,
            e.g. ``"Florence-2-base"``.

    Returns:
        ``(model, processor)``, with the model moved to ``device`` and
        switched to eval mode, or ``(None, None)`` if anything raised while
        loading.
    """
    repo_id = f"microsoft/{model_name}"
    try:
        print(f"Loading {model_name}...")
        # NOTE(security): trust_remote_code=True runs Python code shipped in
        # the model repository; only use it with repositories you trust.
        model = AutoModelForCausalLM.from_pretrained(
            repo_id,
            trust_remote_code=True,
        ).to(device).eval()
        processor = AutoProcessor.from_pretrained(
            repo_id,
            trust_remote_code=True,
        )
        return model, processor
    except Exception as e:
        # Deliberately broad: a failed load is reported and signalled with
        # (None, None) instead of aborting start-up.
        print(f"Error loading {model_name}: {e}")
        return None, None
25
+
26
# Load both models once at import time; each pair is (None, None) if its
# load failed.
model_base, proc_base = load_vlm("Florence-2-base")
model_large, proc_large = load_vlm("Florence-2-large")
29
+
30
def describe_image(uploaded_image, model_choice):
    """Caption an image with the selected Florence-2 model.

    Args:
        uploaded_image: The image to caption: a PIL image, or any object
            accepted by ``Image.fromarray``. ``None`` means nothing was
            uploaded.
        model_choice: ``"Florence-2-base"`` selects the base model; any
            other value selects the large model.

    Returns:
        The value stored under the ``<MORE_DETAILED_CAPTION>`` key of the
        post-processed generation, or a short message when no image was
        supplied or the selected model failed to load.
    """
    if uploaded_image is None:
        return "Please upload an image."

    # Select model based on UI choice
    if model_choice == "Florence-2-base":
        model, processor = model_base, proc_base
    else:
        model, processor = model_large, proc_large

    if model is None:
        return f"{model_choice} failed to load."

    if not isinstance(uploaded_image, Image.Image):
        uploaded_image = Image.fromarray(uploaded_image)
    # Uploads may be RGBA, grayscale or palette images; normalise them to
    # 3-channel RGB before handing them to the processor.
    if uploaded_image.mode != "RGB":
        uploaded_image = uploaded_image.convert("RGB")

    # The same task token is used as the prompt, for post-processing, and as
    # the key of the result dictionary.
    task = "<MORE_DETAILED_CAPTION>"

    # Core generation logic
    inputs = processor(text=task, images=uploaded_image, return_tensors="pt").to(device)

    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
            do_sample=False,
        )

    # Special tokens are kept in the decoded text that is passed on to
    # post_process_generation.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    result = processor.post_process_generation(
        generated_text,
        task=task,
        image_size=(uploaded_image.width, uploaded_image.height),
    )

    return result[task]
66
+
67
# Simplified Gradio layout: image and model picker on the left, caption on
# the right.
css = ".submit-btn { background-color: #4682B4 !important; color: white !important; }"

with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
    gr.Markdown("# **Florence-2 Models Image Captions**")
    gr.Markdown("> Select the model to use. **Base** is faster; **Large** is more accurate.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="pil")
            model_choice = gr.Radio(
                choices=["Florence-2-base", "Florence-2-large"],
                label="Model Choice",
                value="Florence-2-base",
            )
            generate_btn = gr.Button("Generate Caption", elem_classes="submit-btn")

        with gr.Column():
            output = gr.Textbox(label="Generated Caption", lines=6, interactive=True)

    # Event wiring has to happen inside the Blocks context.
    generate_btn.click(
        fn=describe_image,
        inputs=[image_input, model_choice],
        outputs=output,
    )
92
+
93
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()