Fred808 committed on
Commit
f4f8231
·
verified ·
1 Parent(s): 05ab361

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -26
app.py CHANGED
@@ -33,7 +33,7 @@ class ImageAnalysisResponse(BaseModel):
33
  # ===== Load Florence-2 Base Model =====
34
  print("[INFO] Loading Florence-2 model on CPU...")
35
  try:
36
- MODEL_ID = "microsoft/Florence-2-large"
37
 
38
  # Load processor
39
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
@@ -44,7 +44,6 @@ try:
44
  trust_remote_code=True,
45
  torch_dtype=torch.float32,
46
  attn_implementation="eager", # Force eager attention to avoid SDPA issues
47
- device_map=None # Explicitly set to None for CPU
48
  )
49
 
50
  # Move to device manually
@@ -54,23 +53,8 @@ try:
54
  print("[INFO] Model loaded successfully!")
55
  except Exception as e:
56
  print(f"[ERROR] Failed to load model: {e}")
57
- # Try fallback to base model if large fails
58
- try:
59
- print("[INFO] Trying Florence-2-base as fallback...")
60
- MODEL_ID = "microsoft/Florence-2-base"
61
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
62
- model = AutoModelForCausalLM.from_pretrained(
63
- MODEL_ID,
64
- trust_remote_code=True,
65
- torch_dtype=torch.float32,
66
- attn_implementation="eager",
67
- device_map=None
68
- ).to(DEVICE).eval()
69
- print("[INFO] Fallback model loaded successfully!")
70
- except Exception as fallback_error:
71
- print(f"[ERROR] Fallback also failed: {fallback_error}")
72
- processor = None
73
- model = None
74
 
75
  # ===== Helper Functions =====
76
  def download_image(url: str) -> Image.Image:
@@ -107,39 +91,75 @@ def analyze_image(image: Image.Image) -> str:
107
  raise ValueError("Model not loaded properly")
108
 
109
  try:
 
 
110
  # Resize image for faster processing
111
- image = image.resize(RESIZE_DIM, Image.BILINEAR)
 
 
112
 
113
- # Prepare inputs with hardcoded task
 
 
114
  inputs = processor(
115
  text=TASK,
116
  images=image,
117
- return_tensors="pt"
118
- ).to(DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  # Generate caption with error handling
 
121
  with torch.no_grad():
122
  generated_ids = model.generate(
123
  input_ids=inputs["input_ids"],
 
124
  pixel_values=inputs["pixel_values"],
125
- max_new_tokens=256, # Reduced for stability
126
- num_beams=3,
127
  do_sample=False,
128
  early_stopping=True,
129
- pad_token_id=processor.tokenizer.eos_token_id
 
130
  )
131
 
 
 
132
  # Decode and clean output
133
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
134
 
135
  # Remove the task prompt from the beginning if present
136
  if generated_text.startswith(TASK):
137
  generated_text = generated_text[len(TASK):].strip()
138
 
 
139
  return generated_text
140
 
141
  except Exception as e:
142
  print(f"[ERROR] Exception in analyze_image: {e}")
 
 
143
  raise ValueError(f"Failed to analyze image: {e}")
144
 
145
  # ===== API Endpoints =====
@@ -202,6 +222,8 @@ async def analyze_image_endpoint(request: ImageAnalysisRequest):
202
  )
203
  except Exception as e:
204
  print(f"[ERROR] Unexpected error: {e}")
 
 
205
  return ImageAnalysisResponse(
206
  caption="",
207
  success=False,
@@ -220,6 +242,31 @@ async def analyze_image_get(image_url: str):
220
  except Exception as e:
221
  raise HTTPException(status_code=400, detail=str(e))
222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  # ===== Main Execution =====
224
  if __name__ == "__main__":
225
  port = int(os.getenv("PORT", 7860))
 
33
  # ===== Load Florence-2 Base Model =====
34
  print("[INFO] Loading Florence-2 model on CPU...")
35
  try:
36
+ MODEL_ID = "microsoft/Florence-2-base" # Using base for better compatibility
37
 
38
  # Load processor
39
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 
44
  trust_remote_code=True,
45
  torch_dtype=torch.float32,
46
  attn_implementation="eager", # Force eager attention to avoid SDPA issues
 
47
  )
48
 
49
  # Move to device manually
 
53
  print("[INFO] Model loaded successfully!")
54
  except Exception as e:
55
  print(f"[ERROR] Failed to load model: {e}")
56
+ processor = None
57
+ model = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # ===== Helper Functions =====
60
  def download_image(url: str) -> Image.Image:
 
91
  raise ValueError("Model not loaded properly")
92
 
93
  try:
94
+ print(f"[DEBUG] Input image size: {image.size}, mode: {image.mode}")
95
+
96
  # Resize image for faster processing
97
+ original_size = image.size
98
+ image = image.resize(RESIZE_DIM, Image.LANCZOS)
99
+ print(f"[DEBUG] Resized image: {original_size} -> {image.size}")
100
 
101
+ # Prepare inputs with explicit attention mask handling
102
+ print(f"[DEBUG] Processing image with task: {TASK}")
103
+
104
  inputs = processor(
105
  text=TASK,
106
  images=image,
107
+ return_tensors="pt",
108
+ padding=True,
109
+ truncation=True
110
+ )
111
+
112
+ print(f"[DEBUG] Input keys: {list(inputs.keys())}")
113
+ print(f"[DEBUG] Pixel values type: {type(inputs.get('pixel_values'))}")
114
+ if inputs.get('pixel_values') is not None:
115
+ print(f"[DEBUG] Pixel values shape: {inputs['pixel_values'].shape}")
116
+ else:
117
+ print("[DEBUG] Pixel values is None!")
118
+ raise ValueError("Pixel values are None - image processing failed")
119
+
120
+ # Move to device
121
+ inputs = {k: v.to(DEVICE) if hasattr(v, 'to') else v for k, v in inputs.items()}
122
+
123
+ # Ensure attention mask is set
124
+ if 'attention_mask' not in inputs:
125
+ inputs['attention_mask'] = torch.ones_like(inputs['input_ids'])
126
+
127
+ print(f"[DEBUG] Input IDs shape: {inputs['input_ids'].shape}")
128
+ print(f"[DEBUG] Attention mask shape: {inputs['attention_mask'].shape}")
129
+ print(f"[DEBUG] Pixel values device: {inputs['pixel_values'].device}")
130
 
131
  # Generate caption with error handling
132
+ print("[DEBUG] Starting generation...")
133
  with torch.no_grad():
134
  generated_ids = model.generate(
135
  input_ids=inputs["input_ids"],
136
+ attention_mask=inputs["attention_mask"],
137
  pixel_values=inputs["pixel_values"],
138
+ max_new_tokens=128, # Reduced for stability
139
+ num_beams=2, # Reduced for CPU
140
  do_sample=False,
141
  early_stopping=True,
142
+ pad_token_id=processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id,
143
+ eos_token_id=processor.tokenizer.eos_token_id
144
  )
145
 
146
+ print("[DEBUG] Generation completed")
147
+
148
  # Decode and clean output
149
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
150
+ print(f"[DEBUG] Raw generated text: {repr(generated_text)}")
151
 
152
  # Remove the task prompt from the beginning if present
153
  if generated_text.startswith(TASK):
154
  generated_text = generated_text[len(TASK):].strip()
155
 
156
+ print(f"[INFO] Final caption: {generated_text}")
157
  return generated_text
158
 
159
  except Exception as e:
160
  print(f"[ERROR] Exception in analyze_image: {e}")
161
+ import traceback
162
+ print(f"[ERROR] Traceback: {traceback.format_exc()}")
163
  raise ValueError(f"Failed to analyze image: {e}")
164
 
165
  # ===== API Endpoints =====
 
222
  )
223
  except Exception as e:
224
  print(f"[ERROR] Unexpected error: {e}")
225
+ import traceback
226
+ print(f"[ERROR] Traceback: {traceback.format_exc()}")
227
  return ImageAnalysisResponse(
228
  caption="",
229
  success=False,
 
242
  except Exception as e:
243
  raise HTTPException(status_code=400, detail=str(e))
244
 
245
+ # ===== Test Endpoint =====
246
+ @app.post("/test-processor")
247
+ async def test_processor(request: ImageAnalysisRequest):
248
+ """Test endpoint to debug the processor without full model inference"""
249
+ try:
250
+ image = download_image(request.image_url)
251
+ print(f"[TEST] Image downloaded: {image.size}")
252
+
253
+ # Test just the processor
254
+ inputs = processor(
255
+ text=TASK,
256
+ images=image,
257
+ return_tensors="pt"
258
+ )
259
+
260
+ return {
261
+ "success": True,
262
+ "input_keys": list(inputs.keys()),
263
+ "input_ids_shape": inputs["input_ids"].shape if "input_ids" in inputs else None,
264
+ "pixel_values_shape": inputs["pixel_values"].shape if "pixel_values" in inputs else None,
265
+ "pixel_values_type": str(inputs["pixel_values"].dtype) if "pixel_values" in inputs else None
266
+ }
267
+ except Exception as e:
268
+ return {"success": False, "error": str(e)}
269
+
270
  # ===== Main Execution =====
271
  if __name__ == "__main__":
272
  port = int(os.getenv("PORT", 7860))