markpbaggett committed on
Commit
f5d8b8a
·
1 Parent(s): 1216cbc
Files changed (2) hide show
  1. app.py +168 -74
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,39 +1,115 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
4
- from qwen_vl_utils import process_vision_info
5
- import json
6
  from PIL import Image
 
 
 
 
 
 
 
 
 
7
 
8
  # Global variables to store model and processor
9
  model = None
10
  processor = None
11
  tokenizer = None
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def load_model():
14
- """Load the Qwen2.5-VL model and processor"""
15
  global model, processor, tokenizer
16
 
17
  if model is None:
18
- print("Loading Qwen2.5-VL-7B-Instruct model...")
19
-
20
- # Load model - using smaller 7B version for Spaces compatibility
21
- model = Qwen2VLForConditionalGeneration.from_pretrained(
22
- "Qwen/Qwen2.5-VL-7B-Instruct",
23
- torch_dtype=torch.bfloat16,
24
- device_map="auto"
25
- )
26
-
27
- # Load processor and tokenizer
28
- processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
29
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
30
-
31
- print("Model loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  return model, processor, tokenizer
34
 
35
  def generate_metadata(image, metadata_type):
36
- """Generate metadata for the uploaded image"""
37
  if image is None:
38
  return "Please upload an image first."
39
 
@@ -54,7 +130,7 @@ def generate_metadata(image, metadata_type):
54
 
55
  prompt = prompts.get(metadata_type, prompts["Basic Description"])
56
 
57
- # Prepare the conversation format expected by Qwen2.5-VL
58
  messages = [
59
  {
60
  "role": "user",
@@ -68,39 +144,70 @@ def generate_metadata(image, metadata_type):
68
  }
69
  ]
70
 
71
- # Process the input
72
- text = processor.apply_chat_template(
73
- messages, tokenize=False, add_generation_prompt=True
74
- )
75
- image_inputs, video_inputs = process_vision_info(messages)
76
- inputs = processor(
77
- text=[text],
78
- images=image_inputs,
79
- videos=video_inputs,
80
- padding=True,
81
- return_tensors="pt",
82
- )
83
- inputs = inputs.to(model.device)
84
-
85
- # Generate response
86
- with torch.no_grad():
87
- generated_ids = model.generate(
88
- **inputs,
89
- max_new_tokens=512,
90
- temperature=0.7,
91
- do_sample=True,
92
- top_p=0.9,
93
- pad_token_id=tokenizer.pad_token_id
94
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- generated_ids_trimmed = [
97
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
98
- ]
99
- output_text = processor.batch_decode(
100
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
101
- )[0]
102
-
103
- return output_text.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  except Exception as e:
106
  return f"Error generating metadata: {str(e)}"
@@ -108,7 +215,6 @@ def generate_metadata(image, metadata_type):
108
  def create_interface():
109
  """Create the Gradio interface"""
110
 
111
- # Custom CSS for better styling
112
  css = """
113
  .metadata-container {
114
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
@@ -169,25 +275,6 @@ def create_interface():
169
  elem_classes=["output-text"]
170
  )
171
 
172
- # Example images section
173
- gr.HTML("<h3 style='text-align: center; margin-top: 30px;'>Try these example images:</h3>")
174
-
175
- with gr.Row():
176
- example_images = [
177
- ["examples/landscape.jpg", "Scene & Context"],
178
- ["examples/portrait.jpg", "Objects & People"],
179
- ["examples/food.jpg", "Basic Description"],
180
- ["examples/architecture.jpg", "Technical Analysis"]
181
- ]
182
-
183
- gr.Examples(
184
- examples=example_images,
185
- inputs=[image_input, metadata_type],
186
- outputs=output_text,
187
- fn=generate_metadata,
188
- cache_examples=False
189
- )
190
-
191
  # Event handlers
192
  generate_btn.click(
193
  fn=generate_metadata,
@@ -196,7 +283,7 @@ def create_interface():
196
  show_progress=True
197
  )
198
 
199
- # Auto-generate on image upload with basic description
200
  image_input.change(
201
  fn=lambda img: generate_metadata(img, "Basic Description") if img else "",
202
  inputs=[image_input],
@@ -207,21 +294,28 @@ def create_interface():
207
  gr.HTML("""
208
  <div style="text-align: center; padding: 20px; margin-top: 30px; border-top: 1px solid #eee;">
209
  <p style="color: #666;">
210
- This Space uses Qwen2.5-VL-7B-Instruct for intelligent image analysis and metadata generation.
211
  <br>Perfect for content management, SEO optimization, and accessibility improvements.
212
  </p>
 
 
 
213
  </div>
214
  """)
215
 
216
  return interface
217
 
218
- # Initialize the model when the app starts (optional - can be lazy loaded)
219
  def initialize_app():
220
  """Initialize the application"""
221
  print("Starting Image Metadata Generator...")
222
  print("Model will be loaded on first use to save resources.")
223
 
224
- # Create and launch interface
 
 
 
 
 
225
  interface = create_interface()
226
  return interface
227
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
 
 
4
  from PIL import Image
5
+ import json
6
+
7
+ # Try to import qwen_vl_utils, fallback if not available
8
+ try:
9
+ from qwen_vl_utils import process_vision_info
10
+ QWEN_UTILS_AVAILABLE = True
11
+ except ImportError:
12
+ print("Warning: qwen_vl_utils not available, using fallback processing")
13
+ QWEN_UTILS_AVAILABLE = False
14
 
15
  # Global variables to store model and processor
16
  model = None
17
  processor = None
18
  tokenizer = None
19
 
20
def process_vision_info_fallback(messages):
    """Collect image and video payloads from chat messages.

    Stand-in for ``qwen_vl_utils.process_vision_info`` when that package is
    not installed. Only ``"user"`` messages are inspected. Inside each
    message's ``content`` list, entries whose ``"type"`` is ``"image"`` or
    ``"video"`` contribute the value stored under the key of the same name;
    the value is passed through unchanged (no loading or resizing is done).

    Args:
        messages: Iterable of message dicts shaped like
            ``{"role": ..., "content": [{"type": ..., ...}, ...]}``.
            ``None`` is treated as an empty conversation.

    Returns:
        A ``(image_inputs, video_inputs)`` tuple. Each element is a list of
        payloads, or ``None`` when no payload of that kind was found. ``None``
        (not ``[]``) is what ``process_vision_info`` reports for an absent
        modality, and it lets the processor skip that modality entirely.
    """
    image_inputs = []
    video_inputs = []

    for message in messages or []:
        if message.get("role") != "user":
            continue

        content = message.get("content")
        # A plain-string content is text only; iterating it would yield
        # characters, which have no ``.get``.
        if not isinstance(content, (list, tuple)):
            continue

        for part in content:
            if not isinstance(part, dict):
                continue
            kind = part.get("type")
            # Skip typed entries that carry no payload instead of raising.
            if kind == "image" and part.get("image") is not None:
                image_inputs.append(part["image"])
            elif kind == "video" and part.get("video") is not None:
                video_inputs.append(part["video"])

    return image_inputs or None, video_inputs or None
34
+
35
  def load_model():
36
+ """Load the Qwen2.5-VL model and processor with better error handling"""
37
  global model, processor, tokenizer
38
 
39
  if model is None:
40
+ try:
41
+ print("Loading Qwen2.5-VL-7B-Instruct model...")
42
+
43
+ # Try different model loading strategies
44
+ model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
45
+
46
+ # Load processor first (often more stable)
47
+ print("Loading processor...")
48
+ processor = AutoProcessor.from_pretrained(
49
+ model_id,
50
+ trust_remote_code=True
51
+ )
52
+
53
+ # Load tokenizer
54
+ print("Loading tokenizer...")
55
+ tokenizer = AutoTokenizer.from_pretrained(
56
+ model_id,
57
+ trust_remote_code=True
58
+ )
59
+
60
+ # Load model with more conservative settings
61
+ print("Loading model... This may take a few minutes...")
62
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
63
+ model_id,
64
+ torch_dtype=torch.bfloat16,
65
+ device_map="auto",
66
+ trust_remote_code=True,
67
+ # Add these parameters for better compatibility
68
+ attn_implementation="flash_attention_2" if torch.cuda.is_available() else "eager",
69
+ low_cpu_mem_usage=True,
70
+ )
71
+
72
+ print("Model loaded successfully!")
73
+
74
+ except Exception as e:
75
+ print(f"Error loading main model: {e}")
76
+ print("Trying alternative loading method...")
77
+
78
+ try:
79
+ # Fallback: try loading with different parameters
80
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
81
+ model_id,
82
+ torch_dtype=torch.float16, # Try float16 instead
83
+ device_map="cpu", # Force CPU loading
84
+ trust_remote_code=True,
85
+ low_cpu_mem_usage=True,
86
+ )
87
+ print("Model loaded with fallback method!")
88
+
89
+ except Exception as e2:
90
+ print(f"Fallback loading also failed: {e2}")
91
+ print("Trying smaller Qwen2-VL model...")
92
+
93
+ try:
94
+ # Try the older Qwen2-VL model as final fallback
95
+ model_id = "Qwen/Qwen2-VL-7B-Instruct"
96
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
97
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
98
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
99
+ model_id,
100
+ torch_dtype=torch.float16,
101
+ device_map="auto",
102
+ trust_remote_code=True,
103
+ )
104
+ print("Loaded Qwen2-VL (older version) successfully!")
105
+
106
+ except Exception as e3:
107
+ raise Exception(f"All model loading attempts failed. Last error: {e3}")
108
 
109
  return model, processor, tokenizer
110
 
111
  def generate_metadata(image, metadata_type):
112
+ """Generate metadata for the uploaded image with improved error handling"""
113
  if image is None:
114
  return "Please upload an image first."
115
 
 
130
 
131
  prompt = prompts.get(metadata_type, prompts["Basic Description"])
132
 
133
+ # Prepare the conversation format
134
  messages = [
135
  {
136
  "role": "user",
 
144
  }
145
  ]
146
 
147
+ # Process the input with error handling
148
+ try:
149
+ text = processor.apply_chat_template(
150
+ messages, tokenize=False, add_generation_prompt=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  )
152
+
153
+ # Use appropriate vision processing
154
+ if QWEN_UTILS_AVAILABLE:
155
+ image_inputs, video_inputs = process_vision_info(messages)
156
+ else:
157
+ image_inputs, video_inputs = process_vision_info_fallback(messages)
158
+
159
+ inputs = processor(
160
+ text=[text],
161
+ images=image_inputs,
162
+ videos=video_inputs,
163
+ padding=True,
164
+ return_tensors="pt",
165
+ )
166
+
167
+ # Move to device
168
+ inputs = inputs.to(model.device)
169
+
170
+ except Exception as e:
171
+ print(f"Error in input processing: {e}")
172
+ # Fallback to simpler processing
173
+ try:
174
+ inputs = processor(
175
+ text=prompt,
176
+ images=image,
177
+ return_tensors="pt",
178
+ padding=True
179
+ )
180
+ inputs = inputs.to(model.device)
181
+ except Exception as e2:
182
+ return f"Error processing input: {str(e2)}"
183
 
184
+ # Generate response with conservative parameters
185
+ try:
186
+ with torch.no_grad():
187
+ generated_ids = model.generate(
188
+ **inputs,
189
+ max_new_tokens=384, # Reduced from 512
190
+ temperature=0.7,
191
+ do_sample=True,
192
+ top_p=0.9,
193
+ pad_token_id=tokenizer.pad_token_id,
194
+ eos_token_id=tokenizer.eos_token_id,
195
+ )
196
+
197
+ # Extract and decode the response
198
+ generated_ids_trimmed = [
199
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
200
+ ]
201
+ output_text = processor.batch_decode(
202
+ generated_ids_trimmed,
203
+ skip_special_tokens=True,
204
+ clean_up_tokenization_spaces=False
205
+ )[0]
206
+
207
+ return output_text.strip()
208
+
209
+ except Exception as e:
210
+ return f"Error during generation: {str(e)}"
211
 
212
  except Exception as e:
213
  return f"Error generating metadata: {str(e)}"
 
215
  def create_interface():
216
  """Create the Gradio interface"""
217
 
 
218
  css = """
219
  .metadata-container {
220
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
 
275
  elem_classes=["output-text"]
276
  )
277
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  # Event handlers
279
  generate_btn.click(
280
  fn=generate_metadata,
 
283
  show_progress=True
284
  )
285
 
286
+ # Auto-generate on image upload
287
  image_input.change(
288
  fn=lambda img: generate_metadata(img, "Basic Description") if img else "",
289
  inputs=[image_input],
 
294
  gr.HTML("""
295
  <div style="text-align: center; padding: 20px; margin-top: 30px; border-top: 1px solid #eee;">
296
  <p style="color: #666;">
297
+ This Space uses Qwen2.5-VL for intelligent image analysis and metadata generation.
298
  <br>Perfect for content management, SEO optimization, and accessibility improvements.
299
  </p>
300
+ <p style="color: #888; font-size: 0.9em; margin-top: 10px;">
301
+ <strong>Note:</strong> First generation may take 1-2 minutes while the model loads. Subsequent generations will be much faster.
302
+ </p>
303
  </div>
304
  """)
305
 
306
  return interface
307
 
 
308
  def initialize_app():
309
  """Initialize the application"""
310
  print("Starting Image Metadata Generator...")
311
  print("Model will be loaded on first use to save resources.")
312
 
313
+ # Print system info for debugging
314
+ print(f"PyTorch version: {torch.__version__}")
315
+ print(f"CUDA available: {torch.cuda.is_available()}")
316
+ if torch.cuda.is_available():
317
+ print(f"CUDA device: {torch.cuda.get_device_name(0)}")
318
+
319
  interface = create_interface()
320
  return interface
321
 
requirements.txt CHANGED
@@ -7,3 +7,5 @@ qwen-vl-utils
7
  torchvision
8
  numpy
9
  requests
 
 
 
7
  torchvision
8
  numpy
9
  requests
10
+ flash-attn>=2.0.0
11
+ einops