Fred808 committed on
Commit
bf125da
·
verified ·
1 Parent(s): f4f8231

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -68
app.py CHANGED
@@ -33,17 +33,16 @@ class ImageAnalysisResponse(BaseModel):
33
  # ===== Load Florence-2 Base Model =====
34
  print("[INFO] Loading Florence-2 model on CPU...")
35
  try:
36
- MODEL_ID = "microsoft/Florence-2-base" # Using base for better compatibility
37
 
38
  # Load processor
39
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
40
 
41
- # Load model with specific parameters to avoid SDPA issues
42
  model = AutoModelForCausalLM.from_pretrained(
43
  MODEL_ID,
44
  trust_remote_code=True,
45
  torch_dtype=torch.float32,
46
- attn_implementation="eager", # Force eager attention to avoid SDPA issues
47
  )
48
 
49
  # Move to device manually
@@ -60,19 +59,16 @@ except Exception as e:
60
  def download_image(url: str) -> Image.Image:
61
  """Download image from URL and return PIL Image"""
62
  try:
63
- # Set headers to mimic browser request
64
  headers = {
65
- 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
66
  }
67
 
68
  response = requests.get(str(url), headers=headers, timeout=30)
69
  response.raise_for_status()
70
 
71
- # Check content length
72
  if len(response.content) > MAX_IMAGE_SIZE:
73
- raise ValueError(f"Image too large: {len(response.content)} bytes (max: {MAX_IMAGE_SIZE})")
74
 
75
- # Check if content is actually an image
76
  content_type = response.headers.get('content-type', '')
77
  if not content_type.startswith('image/'):
78
  raise ValueError(f"URL does not point to an image. Content-Type: {content_type}")
@@ -91,56 +87,38 @@ def analyze_image(image: Image.Image) -> str:
91
  raise ValueError("Model not loaded properly")
92
 
93
  try:
94
- print(f"[DEBUG] Input image size: {image.size}, mode: {image.mode}")
95
 
96
- # Resize image for faster processing
97
- original_size = image.size
98
  image = image.resize(RESIZE_DIM, Image.LANCZOS)
99
- print(f"[DEBUG] Resized image: {original_size} -> {image.size}")
100
 
101
- # Prepare inputs with explicit attention mask handling
102
- print(f"[DEBUG] Processing image with task: {TASK}")
103
-
104
  inputs = processor(
105
  text=TASK,
106
  images=image,
107
  return_tensors="pt",
108
- padding=True,
109
- truncation=True
110
  )
111
 
112
  print(f"[DEBUG] Input keys: {list(inputs.keys())}")
113
- print(f"[DEBUG] Pixel values type: {type(inputs.get('pixel_values'))}")
114
- if inputs.get('pixel_values') is not None:
115
- print(f"[DEBUG] Pixel values shape: {inputs['pixel_values'].shape}")
116
- else:
117
- print("[DEBUG] Pixel values is None!")
118
- raise ValueError("Pixel values are None - image processing failed")
119
 
120
  # Move to device
121
- inputs = {k: v.to(DEVICE) if hasattr(v, 'to') else v for k, v in inputs.items()}
122
-
123
- # Ensure attention mask is set
124
- if 'attention_mask' not in inputs:
125
- inputs['attention_mask'] = torch.ones_like(inputs['input_ids'])
126
 
127
- print(f"[DEBUG] Input IDs shape: {inputs['input_ids'].shape}")
128
- print(f"[DEBUG] Attention mask shape: {inputs['attention_mask'].shape}")
129
- print(f"[DEBUG] Pixel values device: {inputs['pixel_values'].device}")
130
-
131
- # Generate caption with error handling
132
  print("[DEBUG] Starting generation...")
133
  with torch.no_grad():
134
  generated_ids = model.generate(
135
  input_ids=inputs["input_ids"],
136
- attention_mask=inputs["attention_mask"],
137
  pixel_values=inputs["pixel_values"],
138
- max_new_tokens=128, # Reduced for stability
139
- num_beams=2, # Reduced for CPU
140
  do_sample=False,
141
  early_stopping=True,
142
- pad_token_id=processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id,
143
- eos_token_id=processor.tokenizer.eos_token_id
144
  )
145
 
146
  print("[DEBUG] Generation completed")
@@ -190,19 +168,16 @@ async def analyze_image_endpoint(request: ImageAnalysisRequest):
190
  Always uses <MORE_DETAILED_CAPTION> task for detailed image descriptions
191
  """
192
  try:
193
- # Validate model is loaded
194
  if not processor or not model:
195
  raise HTTPException(
196
  status_code=503,
197
  detail="Model not loaded. Please check server logs."
198
  )
199
 
200
- # Download and process image
201
  print(f"[INFO] Processing image from: {request.image_url}")
202
  image = download_image(request.image_url)
203
  print(f"[INFO] Image downloaded successfully: {image.size}")
204
 
205
- # Analyze image with hardcoded task
206
  caption = analyze_image(image)
207
  print(f"[INFO] Analysis complete")
208
 
@@ -222,8 +197,6 @@ async def analyze_image_endpoint(request: ImageAnalysisRequest):
222
  )
223
  except Exception as e:
224
  print(f"[ERROR] Unexpected error: {e}")
225
- import traceback
226
- print(f"[ERROR] Traceback: {traceback.format_exc()}")
227
  return ImageAnalysisResponse(
228
  caption="",
229
  success=False,
@@ -242,31 +215,6 @@ async def analyze_image_get(image_url: str):
242
  except Exception as e:
243
  raise HTTPException(status_code=400, detail=str(e))
244
 
245
- # ===== Test Endpoint =====
246
- @app.post("/test-processor")
247
- async def test_processor(request: ImageAnalysisRequest):
248
- """Test endpoint to debug the processor without full model inference"""
249
- try:
250
- image = download_image(request.image_url)
251
- print(f"[TEST] Image downloaded: {image.size}")
252
-
253
- # Test just the processor
254
- inputs = processor(
255
- text=TASK,
256
- images=image,
257
- return_tensors="pt"
258
- )
259
-
260
- return {
261
- "success": True,
262
- "input_keys": list(inputs.keys()),
263
- "input_ids_shape": inputs["input_ids"].shape if "input_ids" in inputs else None,
264
- "pixel_values_shape": inputs["pixel_values"].shape if "pixel_values" in inputs else None,
265
- "pixel_values_type": str(inputs["pixel_values"].dtype) if "pixel_values" in inputs else None
266
- }
267
- except Exception as e:
268
- return {"success": False, "error": str(e)}
269
-
270
  # ===== Main Execution =====
271
  if __name__ == "__main__":
272
  port = int(os.getenv("PORT", 7860))
 
33
  # ===== Load Florence-2 Base Model =====
34
  print("[INFO] Loading Florence-2 model on CPU...")
35
  try:
36
+ MODEL_ID = "microsoft/Florence-2-base"
37
 
38
  # Load processor
39
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
40
 
41
+ # Load model
42
  model = AutoModelForCausalLM.from_pretrained(
43
  MODEL_ID,
44
  trust_remote_code=True,
45
  torch_dtype=torch.float32,
 
46
  )
47
 
48
  # Move to device manually
 
59
  def download_image(url: str) -> Image.Image:
60
  """Download image from URL and return PIL Image"""
61
  try:
 
62
  headers = {
63
+ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
64
  }
65
 
66
  response = requests.get(str(url), headers=headers, timeout=30)
67
  response.raise_for_status()
68
 
 
69
  if len(response.content) > MAX_IMAGE_SIZE:
70
+ raise ValueError(f"Image too large: {len(response.content)} bytes")
71
 
 
72
  content_type = response.headers.get('content-type', '')
73
  if not content_type.startswith('image/'):
74
  raise ValueError(f"URL does not point to an image. Content-Type: {content_type}")
 
87
  raise ValueError("Model not loaded properly")
88
 
89
  try:
90
+ print(f"[DEBUG] Input image size: {image.size}")
91
 
92
+ # Resize image
 
93
  image = image.resize(RESIZE_DIM, Image.LANCZOS)
 
94
 
95
+ # Prepare inputs - use the same approach that worked in the test
 
 
96
  inputs = processor(
97
  text=TASK,
98
  images=image,
99
  return_tensors="pt",
100
+ padding=True
 
101
  )
102
 
103
  print(f"[DEBUG] Input keys: {list(inputs.keys())}")
104
+ print(f"[DEBUG] Input IDs shape: {inputs['input_ids'].shape}")
105
+ print(f"[DEBUG] Pixel values shape: {inputs['pixel_values'].shape}")
 
 
 
 
106
 
107
  # Move to device
108
+ inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
 
 
 
 
109
 
110
+ # Generate caption - use the specific Florence-2 generation approach
 
 
 
 
111
  print("[DEBUG] Starting generation...")
112
  with torch.no_grad():
113
  generated_ids = model.generate(
114
  input_ids=inputs["input_ids"],
 
115
  pixel_values=inputs["pixel_values"],
116
+ max_new_tokens=100,
117
+ num_beams=3,
118
  do_sample=False,
119
  early_stopping=True,
120
+ no_repeat_ngram_size=3,
121
+ length_penalty=1.0,
122
  )
123
 
124
  print("[DEBUG] Generation completed")
 
168
  Always uses <MORE_DETAILED_CAPTION> task for detailed image descriptions
169
  """
170
  try:
 
171
  if not processor or not model:
172
  raise HTTPException(
173
  status_code=503,
174
  detail="Model not loaded. Please check server logs."
175
  )
176
 
 
177
  print(f"[INFO] Processing image from: {request.image_url}")
178
  image = download_image(request.image_url)
179
  print(f"[INFO] Image downloaded successfully: {image.size}")
180
 
 
181
  caption = analyze_image(image)
182
  print(f"[INFO] Analysis complete")
183
 
 
197
  )
198
  except Exception as e:
199
  print(f"[ERROR] Unexpected error: {e}")
 
 
200
  return ImageAnalysisResponse(
201
  caption="",
202
  success=False,
 
215
  except Exception as e:
216
  raise HTTPException(status_code=400, detail=str(e))
217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  # ===== Main Execution =====
219
  if __name__ == "__main__":
220
  port = int(os.getenv("PORT", 7860))