Florian committed
Commit 0c57df9 · 1 Parent(s): f04ff47

first commit

Files changed (1)
  1. app.py +272 -0
app.py ADDED
@@ -0,0 +1,272 @@
+ import os
+ import time
+
+ import torch
+ from PIL import Image
+ from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+
+ os.environ["GRADIO_TEMP_DIR"] = (
+     "/home/agent_vision@BEIJAFLORE.COM/fmorel/CoVT-main/CoVT-main/gradio/temp"
+ )
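+ # GRADIO_TEMP_DIR is set above, before importing gradio, so that the library
+ # picks up the custom temp/cache directory for uploaded files.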
+ import gradio as gr
+
+ # ================= Configuration Area =================
+ # You can change these defaults as you like
+ DEFAULT_MODEL_NAME = "Wakals/CoVT-7B-seg_depth_dino"
+ DEFAULT_CKPT_PATH = None  # Or set to your local checkpoint path
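+ # e.g. DEFAULT_CKPT_PATH = "/path/to/local/checkpoint"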
+ # ======================================================
+
+ # Global cache for model and processor to avoid re-loading every call
+ _cached_model = None
+ _cached_processor = None
+
+
+ def load_model_and_processor(
+     model_name: str,
+     ckpt: str = None,
+ ):
+     """
+     Load a single CoVT-7B model and its corresponding processor.
+     """
+     if ckpt is not None:
+         print(f"Loading model from ckpt: {ckpt}")
+         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+             ckpt, torch_dtype=torch.bfloat16, device_map="auto"
+         ).eval()
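+         # min_pixels / max_pixels bound the image resolution the processor may
+         # use; in Qwen2.5-VL one visual token covers a 28x28 pixel area, so
+         # this corresponds to roughly 256-1280 visual tokens per image.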
+         processor = AutoProcessor.from_pretrained(
+             ckpt, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28
+         )
+     else:
+         print(f"Loading model from hub: {model_name}")
+         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+             model_name, torch_dtype=torch.bfloat16, device_map="auto"
+         ).eval()
+         processor = AutoProcessor.from_pretrained(
+             model_name, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28
+         )
+
+     return model, processor
+
+
+ def get_cached_model_and_processor(
+     model_name: str = DEFAULT_MODEL_NAME,
+     ckpt: str = DEFAULT_CKPT_PATH,
+ ):
+     """
+     Lazy-load and cache the model and processor so they are not reloaded every request.
+     """
+     global _cached_model, _cached_processor
+
+     # If already loaded, just return them
+     if _cached_model is not None and _cached_processor is not None:
+         return _cached_model, _cached_processor
+
+     # Otherwise load and cache
+     _cached_model, _cached_processor = load_model_and_processor(
+         model_name=model_name,
+         ckpt=ckpt,
+     )
+     return _cached_model, _cached_processor
+
+
+ def run_single_inference(
+     model,
+     processor,
+     image,  # can be either a PIL.Image or a path string
+     question: str,
+     max_new_tokens: int = 512,
+     temperature: float = 0.0,
+     top_p: float = 0.9,
+     do_sample: bool = False,
+     seed: int = 42,
+ ):
+     """
+     Single inference: given one image and one question, return answer and elapsed time.
+     """
+     # 1) Prepare conversation
+     # For Gradio we usually get a PIL image, but we also support a path string for compatibility.
+     if isinstance(image, str):
+         pil_image = Image.open(image).convert("RGB")
+         image_ref = image  # path for the "image" field
+     elif isinstance(image, Image.Image):
+         pil_image = image.convert("RGB")
+         # When using a PIL image in the chat template, you can pass a placeholder
+         # and rely on the 'images' argument of the processor; here we still need a "dummy" reference.
+         image_ref = (
+             "gradio_image"  # this is not used as a real path, just a placeholder
+         )
+     else:
+         raise ValueError("image must be a PIL.Image or a path string.")
+
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": image_ref},
+                 {"type": "text", "text": question},
+             ],
+         }
+     ]
+
+     # 2) Apply chat template
+     prompt = processor.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+
+     # 3) Encode image and text
+     inputs = processor(text=[prompt], images=[pil_image], return_tensors="pt")
+
+     # Move inputs to the same device as the model
+     device = model.device
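+     # With device_map="auto" the model may be sharded across devices;
+     # model.device reports the first shard, which is where inputs belong.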
+     inputs = {
+         k: (v.to(device) if isinstance(v, torch.Tensor) else v)
+         for k, v in inputs.items()
+     }
+
+     # 3.5) Set random seeds for reproducibility when sampling.
+     # Note: transformers' generate() does not accept a torch.Generator, so
+     # reproducibility relies on the global torch seeds set here.
+     seed = int(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+
+     # 4) Timing + generation
+     if device.type == "cuda":
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+
+     start = time.time()
+     with torch.no_grad():
+         generated_ids = model.generate(
+             **inputs,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_p=top_p,
+             do_sample=do_sample,
+             pad_token_id=processor.tokenizer.eos_token_id,
+             eos_token_id=processor.tokenizer.eos_token_id,
+         )
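+         # Note: transformers may warn that temperature/top_p are ignored when
+         # do_sample=False (the greedy path used when temperature is 0).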
+     if device.type == "cuda":
+         torch.cuda.synchronize()
+     end = time.time()
+
+     elapsed = end - start
+
+     # 5) Decode only newly generated tokens
+     input_len = inputs["input_ids"].shape[1]
+     new_tokens = generated_ids[0, input_len:]
+     answer = processor.decode(new_tokens, skip_special_tokens=True)
+
+     return answer, elapsed
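+
+ # Example usage (with a hypothetical local image path):
+ #   model, processor = get_cached_model_and_processor()
+ #   answer, seconds = run_single_inference(model, processor, "example.jpg", "Describe the image.")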
+
+
+ def gradio_inference(
+     image,
+     question,
+     max_new_tokens,
+     temperature,
+     top_p,
+     seed,
+ ):
+     """
+     Wrapper function for Gradio that calls the inference logic and returns answer + time cost.
+     """
+     if image is None:
+         return "Please upload an image.", 0.0
+
+     # Get (or load) model and processor
+     model, processor = get_cached_model_and_processor()
+
+     # Run inference
+     answer, elapsed = run_single_inference(
+         model=model,
+         processor=processor,
+         image=image,  # PIL image from the Gradio Image component (type="pil")
+         question=question,
+         max_new_tokens=int(max_new_tokens),
+         temperature=float(temperature),
+         top_p=float(top_p),
+         do_sample=(temperature > 0.0),
+         seed=int(seed),
+     )
+
+     return answer, elapsed
+
+
+ # ===================== Gradio UI =====================
+
+
+ def build_demo():
+     with gr.Blocks() as demo:
+         gr.Markdown(
+             "# CoVT-7B Gradio Demo\n"
+             "Upload an image and input a question to run visual question answering."
+         )
+
+         with gr.Row():
+             with gr.Column():
+                 image_input = gr.Image(label="Input Image", type="pil")
+                 question_input = gr.Textbox(label="Question", value="", lines=2)
+                 max_new_tokens = gr.Slider(
+                     label="max_new_tokens", minimum=1, maximum=1024, value=512, step=1
+                 )
+                 temperature = gr.Slider(
+                     label="temperature", minimum=0.0, maximum=1.0, value=0.0, step=0.01
+                 )
+                 top_p = gr.Slider(
+                     label="top_p", minimum=0.1, maximum=1.0, value=0.9, step=0.01
+                 )
+                 seed = gr.Slider(
+                     label="random_seed", minimum=0, maximum=1000, value=42, step=1
+                 )
+
+                 gr.Markdown("### Example")
+                 example_image_path = os.path.abspath(
+                     os.path.join(
+                         os.path.dirname(__file__), "..", "assets", "clouds.png"
+                     )
+                 )
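+                 # NOTE: assumes assets/clouds.png exists one directory above
+                 # app.py; adjust the path if your repository layout differs.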
+                 example_image = Image.open(example_image_path).convert("RGB")
+                 gr.Examples(
+                     examples=[
+                         [
+                             example_image,
+                             "Describe the scene in the picture in detail, and find out how many clouds are in the sky. Use segmentation, depth map, and perception feature information of the image to answer this question.",
+                         ]
+                     ],
+                     inputs=[image_input, question_input],
+                     examples_per_page=1,
+                 )
+                 # -----------------------------------------
+
+                 run_button = gr.Button("Run Inference")
+
+             with gr.Column():
+                 answer_output = gr.Textbox(label="Answer", lines=10)
+                 elapsed_output = gr.Number(label="Elapsed time (seconds)")
+
+         run_button.click(
+             fn=gradio_inference,
+             inputs=[
+                 image_input,
+                 question_input,
+                 max_new_tokens,
+                 temperature,
+                 top_p,
+                 seed,
+             ],
+             outputs=[answer_output, elapsed_output],
+         )
+
+     return demo
+
+
+ if __name__ == "__main__":
+     demo = build_demo()
+     # You can set share=True if you want a public link
+     demo.launch()
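+     # Gradio binds to http://127.0.0.1:7860 by default; pass server_name
+     # and/or server_port to launch() to change this.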