Fred808 commited on
Commit
047f73e
·
verified ·
1 Parent(s): 85a2cee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -46
app.py CHANGED
@@ -12,18 +12,18 @@ import uvicorn
12
  DEVICE = "cpu" # Use CPU for compatibility
13
  RESIZE_DIM = (512, 512) # Resize images to this resolution
14
  MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB max image size
 
15
 
16
  # ===== FastAPI App =====
17
  app = FastAPI(
18
  title="Florence-2 Image Analysis API",
19
- description="Analyze images using Microsoft's Florence-2 model",
20
  version="1.0.0"
21
  )
22
 
23
  # ===== Request/Response Models =====
24
  class ImageAnalysisRequest(BaseModel):
25
  image_url: HttpUrl
26
- task: str = "<MORE_DETAILED_CAPTION>" # Default task
27
 
28
  class ImageAnalysisResponse(BaseModel):
29
  caption: str
@@ -38,8 +38,9 @@ try:
38
  model = AutoModelForCausalLM.from_pretrained(
39
  MODEL_ID,
40
  trust_remote_code=True,
41
- attn_implementation="eager"
42
- ).to(DEVICE).eval()
 
43
  print("[INFO] Model loaded successfully!")
44
  except Exception as e:
45
  print(f"[ERROR] Failed to load model: {e}")
@@ -75,8 +76,8 @@ def download_image(url: str) -> Image.Image:
75
  except Exception as e:
76
  raise ValueError(f"Failed to process image: {e}")
77
 
78
- def analyze_image(image: Image.Image, task: str = "<MORE_DETAILED_CAPTION>") -> str:
79
- """Analyze image using Florence-2 model"""
80
  if not processor or not model:
81
  raise ValueError("Model not loaded properly")
82
 
@@ -84,9 +85,9 @@ def analyze_image(image: Image.Image, task: str = "<MORE_DETAILED_CAPTION>") ->
84
  # Resize image for faster processing
85
  image = image.resize(RESIZE_DIM, Image.BILINEAR)
86
 
87
- # Prepare inputs
88
  inputs = processor(
89
- text=task,
90
  images=image,
91
  return_tensors="pt"
92
  ).to(DEVICE)
@@ -101,10 +102,14 @@ def analyze_image(image: Image.Image, task: str = "<MORE_DETAILED_CAPTION>") ->
101
  do_sample=False
102
  )
103
 
104
- # Decode
105
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
106
- print(f"[DEBUG] Generated text: {generated_text}")
107
- return generated_text.strip()
 
 
 
 
108
 
109
  except Exception as e:
110
  print(f"[ERROR] Exception in analyze_image: {e}")
@@ -117,7 +122,8 @@ async def root():
117
  return {
118
  "message": "Florence-2 Image Analysis API",
119
  "status": "running",
120
- "model_loaded": processor is not None and model is not None
 
121
  }
122
 
123
  @app.get("/health")
@@ -127,39 +133,21 @@ async def health_check():
127
  "status": "healthy" if (processor and model) else "unhealthy",
128
  "model_loaded": processor is not None and model is not None,
129
  "device": DEVICE,
130
- "available_tasks": [
131
- "<MORE_DETAILED_CAPTION>",
132
- "<DETAILED_CAPTION>",
133
- "<CAPTION>",
134
- "<OD>", # Object Detection
135
- "<DENSE_REGION_CAPTION>",
136
- "<REGION_PROPOSAL>"
137
- ]
138
  }
139
 
140
  @app.post("/analyze", response_model=ImageAnalysisResponse)
141
  async def analyze_image_endpoint(request: ImageAnalysisRequest):
142
  """
143
  Analyze an image from a URL using Florence-2 model
144
-
145
- Available tasks:
146
- - <MORE_DETAILED_CAPTION>: Generate detailed image captions
147
- - <DETAILED_CAPTION>: Generate detailed captions
148
- - <CAPTION>: Generate basic captions
149
- - <OD>: Object detection
150
- - <DENSE_REGION_CAPTION>: Dense region captioning
151
- - <REGION_PROPOSAL>: Region proposal
152
  """
153
  try:
154
- # Validate task
155
- valid_tasks = [
156
- "<MORE_DETAILED_CAPTION>", "<DETAILED_CAPTION>", "<CAPTION>",
157
- "<OD>", "<DENSE_REGION_CAPTION>", "<REGION_PROPOSAL>"
158
- ]
159
- if request.task not in valid_tasks:
160
  raise HTTPException(
161
- status_code=400,
162
- detail=f"Invalid task. Available tasks: {valid_tasks}"
163
  )
164
 
165
  # Download and process image
@@ -167,15 +155,17 @@ async def analyze_image_endpoint(request: ImageAnalysisRequest):
167
  image = download_image(request.image_url)
168
  print(f"[INFO] Image downloaded successfully: {image.size}")
169
 
170
- # Analyze image
171
- caption = analyze_image(image, request.task)
172
- print(f"[INFO] Analysis complete: {caption}")
173
 
174
  return ImageAnalysisResponse(
175
  caption=caption,
176
  success=True
177
  )
178
 
 
 
179
  except ValueError as e:
180
  print(f"[ERROR] ValueError: {e}")
181
  return ImageAnalysisResponse(
@@ -185,27 +175,35 @@ async def analyze_image_endpoint(request: ImageAnalysisRequest):
185
  )
186
  except Exception as e:
187
  print(f"[ERROR] Unexpected error: {e}")
188
- raise HTTPException(status_code=500, detail=f"Internal server error: {e}")
 
 
 
 
189
 
190
  @app.get("/analyze")
191
- async def analyze_image_get(image_url: str, task: str = "<MORE_DETAILED_CAPTION>"):
192
  """
193
  GET endpoint for quick image analysis
194
- Usage: /analyze?image_url=https://example.com/image.jpg&task=<MORE_DETAILED_CAPTION>
195
  """
196
- request = ImageAnalysisRequest(image_url=image_url, task=task)
197
- return await analyze_image_endpoint(request)
 
 
 
198
 
199
  # ===== Main Execution =====
200
  if __name__ == "__main__":
201
  port = int(os.getenv("PORT", 7860))
202
  print(f"[INFO] Starting server on port {port}")
203
  print(f"[INFO] Model status: {'Loaded' if (processor and model) else 'Failed to load'}")
 
204
  print(f"[INFO] API Documentation: http://localhost:{port}/docs")
205
 
206
  uvicorn.run(
207
- "app:app",
208
  host="0.0.0.0",
209
  port=port,
210
- reload=False # Set to True for development
211
  )
 
12
  DEVICE = "cpu" # Use CPU for compatibility
13
  RESIZE_DIM = (512, 512) # Resize images to this resolution
14
  MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB max image size
15
+ TASK = "<MORE_DETAILED_CAPTION>" # Hardcoded task
16
 
17
  # ===== FastAPI App =====
18
  app = FastAPI(
19
  title="Florence-2 Image Analysis API",
20
+ description="Analyze images using Microsoft's Florence-2 model with detailed captions",
21
  version="1.0.0"
22
  )
23
 
24
  # ===== Request/Response Models =====
25
  class ImageAnalysisRequest(BaseModel):
26
  image_url: HttpUrl
 
27
 
28
  class ImageAnalysisResponse(BaseModel):
29
  caption: str
 
38
  model = AutoModelForCausalLM.from_pretrained(
39
  MODEL_ID,
40
  trust_remote_code=True,
41
+ torch_dtype=torch.float32,
42
+ device_map="auto"
43
+ ).eval()
44
  print("[INFO] Model loaded successfully!")
45
  except Exception as e:
46
  print(f"[ERROR] Failed to load model: {e}")
 
76
  except Exception as e:
77
  raise ValueError(f"Failed to process image: {e}")
78
 
79
+ def analyze_image(image: Image.Image) -> str:
80
+ """Analyze image using Florence-2 model with hardcoded task"""
81
  if not processor or not model:
82
  raise ValueError("Model not loaded properly")
83
 
 
85
  # Resize image for faster processing
86
  image = image.resize(RESIZE_DIM, Image.BILINEAR)
87
 
88
+ # Prepare inputs with hardcoded task
89
  inputs = processor(
90
+ text=TASK,
91
  images=image,
92
  return_tensors="pt"
93
  ).to(DEVICE)
 
102
  do_sample=False
103
  )
104
 
105
+ # Decode and clean output
106
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
107
+
108
+ # Remove the task prompt from the beginning if present
109
+ if generated_text.startswith(TASK):
110
+ generated_text = generated_text[len(TASK):].strip()
111
+
112
+ return generated_text
113
 
114
  except Exception as e:
115
  print(f"[ERROR] Exception in analyze_image: {e}")
 
122
  return {
123
  "message": "Florence-2 Image Analysis API",
124
  "status": "running",
125
+ "model_loaded": processor is not None and model is not None,
126
+ "task": TASK
127
  }
128
 
129
  @app.get("/health")
 
133
  "status": "healthy" if (processor and model) else "unhealthy",
134
  "model_loaded": processor is not None and model is not None,
135
  "device": DEVICE,
136
+ "task": TASK
 
 
 
 
 
 
 
137
  }
138
 
139
  @app.post("/analyze", response_model=ImageAnalysisResponse)
140
  async def analyze_image_endpoint(request: ImageAnalysisRequest):
141
  """
142
  Analyze an image from a URL using Florence-2 model
143
+ Always uses <MORE_DETAILED_CAPTION> task for detailed image descriptions
 
 
 
 
 
 
 
144
  """
145
  try:
146
+ # Validate model is loaded
147
+ if not processor or not model:
 
 
 
 
148
  raise HTTPException(
149
+ status_code=503,
150
+ detail="Model not loaded. Please check server logs."
151
  )
152
 
153
  # Download and process image
 
155
  image = download_image(request.image_url)
156
  print(f"[INFO] Image downloaded successfully: {image.size}")
157
 
158
+ # Analyze image with hardcoded task
159
+ caption = analyze_image(image)
160
+ print(f"[INFO] Analysis complete")
161
 
162
  return ImageAnalysisResponse(
163
  caption=caption,
164
  success=True
165
  )
166
 
167
+ except HTTPException:
168
+ raise
169
  except ValueError as e:
170
  print(f"[ERROR] ValueError: {e}")
171
  return ImageAnalysisResponse(
 
175
  )
176
  except Exception as e:
177
  print(f"[ERROR] Unexpected error: {e}")
178
+ return ImageAnalysisResponse(
179
+ caption="",
180
+ success=False,
181
+ error_message=f"Internal server error: {str(e)}"
182
+ )
183
 
184
  @app.get("/analyze")
185
+ async def analyze_image_get(image_url: str):
186
  """
187
  GET endpoint for quick image analysis
188
+ Usage: /analyze?image_url=https://example.com/image.jpg
189
  """
190
+ try:
191
+ request = ImageAnalysisRequest(image_url=image_url)
192
+ return await analyze_image_endpoint(request)
193
+ except Exception as e:
194
+ raise HTTPException(status_code=400, detail=str(e))
195
 
196
  # ===== Main Execution =====
197
  if __name__ == "__main__":
198
  port = int(os.getenv("PORT", 7860))
199
  print(f"[INFO] Starting server on port {port}")
200
  print(f"[INFO] Model status: {'Loaded' if (processor and model) else 'Failed to load'}")
201
+ print(f"[INFO] Task: {TASK}")
202
  print(f"[INFO] API Documentation: http://localhost:{port}/docs")
203
 
204
  uvicorn.run(
205
+ app,
206
  host="0.0.0.0",
207
  port=port,
208
+ reload=False
209
  )