Fred808 committed on
Commit
3be11fc
·
verified ·
1 Parent(s): e5dd183

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +212 -339
app.py CHANGED
@@ -1,360 +1,233 @@
1
  import os
2
- import sys
3
- import subprocess
4
- import importlib
5
- import requests
6
- from PIL import Image
7
- from io import BytesIO
8
- from fastapi import FastAPI, HTTPException
9
- from pydantic import BaseModel, HttpUrl
10
- from transformers import AutoProcessor, AutoModelForCausalLM
11
- import uvicorn
12
-
13
- # ===== RUNTIME DEPENDENCY ENSURER =====
14
- # Hardcoded torch version to ensure compatibility at startup.
15
- REQUIRED_TORCH_VERSION = os.getenv("REQUIRED_TORCH_VERSION", "2.2.2")
16
-
17
- def ensure_torch_installed(required_version: str = REQUIRED_TORCH_VERSION):
18
- """Ensure the required torch version is installed at runtime.
19
- This will attempt to import torch and compare versions. If missing or different,
20
- it will pip-install the requested version using the running Python executable.
21
-
22
- Note: Installing torch at every start may be slow and may require build artifacts
23
- specific to the platform. This helper uses a simple pip install; if your target
24
- platform requires a special wheel or extra index URL, set up the environment
25
- outside of this script or modify the install command accordingly.
26
- """
27
- try:
28
- import torch as _t
29
- v = getattr(_t, "__version__", "")
30
- # match major.minor.patch prefix
31
- if v and v.startswith(required_version):
32
- print(f"[INFO] torch {v} already installed")
33
- return _t
34
- else:
35
- print(f"[INFO] torch version {v} != {required_version}, will reinstall")
36
- except Exception:
37
- print("[INFO] torch not found, installing now")
38
 
39
- cmd = [sys.executable, "-m", "pip", "install", f"torch=={required_version}"]
40
- print(f"[INFO] Running: {' '.join(cmd)}")
41
- try:
42
- subprocess.check_call(cmd)
43
- importlib.invalidate_caches()
44
- import torch as _t2
45
- print(f"[INFO] Installed torch {_t2.__version__}")
46
- return _t2
47
- except subprocess.CalledProcessError as e:
48
- print(f"[ERROR] pip install failed: {e}")
49
- raise
50
-
51
-
52
- # Ensure torch is available before using the model
53
- torch = ensure_torch_installed()
54
-
55
- # ===== CONFIG =====
56
- DEVICE = "cpu" # Use CPU for compatibility
57
- RESIZE_DIM = (512, 512) # Resize images to this resolution
58
- MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB max image size
59
- TASK = "<MORE_DETAILED_CAPTION>" # Hardcoded task
60
-
61
- # URL template for frame iteration - replace with your actual URL
62
- BASE_URL_TEMPLATE = "https://example.com/frames/frame_{frame}.jpg"
63
- START_FRAME = 1 # Starting frame number
64
- FRAME_PADDING = 6 # Number of digits to pad frame numbers with
65
-
66
- # ===== FastAPI App =====
67
- app = FastAPI(
68
- title="Florence-2 Image Analysis API",
69
- description="Analyze images using Microsoft's Florence-2 model with detailed captions",
70
- version="1.0.0"
71
- )
72
-
73
- # ===== Request/Response Models =====
74
- class ImageAnalysisRequest(BaseModel):
75
- image_url: HttpUrl
76
-
77
- class ImageAnalysisResponse(BaseModel):
78
- caption: str
79
- success: bool
80
- error_message: str = None
81
-
82
- # ===== Load Florence-2 Base Model =====
83
- print("[INFO] Loading Florence-2 model on CPU...")
84
  try:
85
- MODEL_ID = "microsoft/Florence-2-base"
86
-
87
- # Load processor
88
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
89
-
90
- # Load model
91
- model = AutoModelForCausalLM.from_pretrained(
92
- MODEL_ID,
93
- trust_remote_code=True,
94
- torch_dtype=torch.float32,
95
- )
96
-
97
- # Move to device manually
98
- model = model.to(DEVICE)
99
- model.eval()
100
-
101
- print("[INFO] Model loaded successfully!")
102
- except Exception as e:
103
- print(f"[ERROR] Failed to load model: {e}")
104
- processor = None
105
- model = None
106
-
107
- # ===== Helper Functions =====
108
- def download_image(url: str) -> Image.Image:
109
- """Download image from URL and return PIL Image"""
 
 
 
110
  try:
111
- headers = {
112
- 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
113
- }
114
-
115
- response = requests.get(str(url), headers=headers, timeout=30)
116
- response.raise_for_status()
117
-
118
- if len(response.content) > MAX_IMAGE_SIZE:
119
- raise ValueError(f"Image too large: {len(response.content)} bytes")
120
-
121
- content_type = response.headers.get('content-type', '')
122
- if not content_type.startswith('image/'):
123
- raise ValueError(f"URL does not point to an image. Content-Type: {content_type}")
124
-
125
- image = Image.open(BytesIO(response.content)).convert("RGB")
126
- return image
127
-
128
- except requests.exceptions.RequestException as e:
129
- raise ValueError(f"Failed to download image: {e}")
130
- except Exception as e:
131
- raise ValueError(f"Failed to process image: {e}")
132
 
133
 
134
- def iterate_and_analyze(base_url_template: str, start: int = 1, padding: int = 6):
135
- """Iterate over a templated frame URL and analyze images sequentially.
136
 
137
- base_url_template should contain a placeholder `{frame}` which will be replaced by
138
- the zero-padded frame number, for example:
139
- https://example.com/download?course=XYZ&file=frame%3AXYZ%2F{frame}%2Fframe_000{n}.jpg
140
 
141
- The function yields tuples: (frame_number, url, caption_or_error)
142
- Continues until a frame fails to download (e.g., 404 error)
143
- """
144
- if "{frame}" not in base_url_template:
145
- raise ValueError("base_url_template must contain '{frame}' placeholder")
146
-
147
- consecutive_errors = 0
148
- MAX_CONSECUTIVE_ERRORS = 3 # Stop after this many consecutive errors
149
-
150
- i = start
151
- while True:
152
- frame_id = str(i).zfill(padding)
153
- url = base_url_template.format(frame=frame_id)
 
 
 
 
 
 
 
 
 
154
  try:
155
- img = download_image(url)
156
- caption = analyze_image(img)
157
- consecutive_errors = 0 # Reset error counter on success
158
- yield (i, url, {"success": True, "caption": caption})
159
- except requests.exceptions.HTTPError as e:
160
- if e.response.status_code == 404:
161
- print(f"[INFO] No more frames found after frame {i-1}")
162
- break
163
- yield (i, url, {"success": False, "error": str(e)})
164
- consecutive_errors += 1
165
- except Exception as e:
166
- yield (i, url, {"success": False, "error": str(e)})
167
- consecutive_errors += 1
168
-
169
- if consecutive_errors >= MAX_CONSECUTIVE_ERRORS:
170
- print(f"[INFO] Stopping after {MAX_CONSECUTIVE_ERRORS} consecutive errors")
171
- break
172
-
173
- i += 1
174
-
175
- def analyze_image(image: Image.Image) -> str:
176
- """Analyze image using Florence-2 model with hardcoded task"""
177
- if not processor or not model:
178
- raise ValueError("Model not loaded properly")
179
-
180
  try:
181
- print(f"[DEBUG] Input image size: {image.size}")
182
-
183
- # Resize image
184
- image = image.resize(RESIZE_DIM, Image.LANCZOS)
185
-
186
- # Prepare inputs - use the same approach that worked in the test
187
- inputs = processor(
188
- text=TASK,
189
- images=image,
190
- return_tensors="pt",
191
- padding=True
192
- )
193
-
194
- print(f"[DEBUG] Input keys: {list(inputs.keys())}")
195
- print(f"[DEBUG] Input IDs shape: {inputs['input_ids'].shape}")
196
- print(f"[DEBUG] Pixel values shape: {inputs['pixel_values'].shape}")
197
-
198
- # Move to device
199
- inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
200
-
201
- # Generate caption - use the specific Florence-2 generation approach
202
- print("[DEBUG] Starting generation...")
203
- with torch.no_grad():
204
- generated_ids = model.generate(
205
- input_ids=inputs["input_ids"],
206
- pixel_values=inputs["pixel_values"],
207
- max_new_tokens=100,
208
- num_beams=3,
209
- do_sample=False,
210
- early_stopping=True,
211
- no_repeat_ngram_size=3,
212
- length_penalty=1.0,
213
- )
214
 
215
- print("[DEBUG] Generation completed")
216
-
217
- # Decode and clean output
218
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
219
- print(f"[DEBUG] Raw generated text: {repr(generated_text)}")
220
-
221
- # Remove the task prompt from the beginning if present
222
- if generated_text.startswith(TASK):
223
- generated_text = generated_text[len(TASK):].strip()
224
-
225
- print(f"[INFO] Final caption: {generated_text}")
226
- return generated_text
227
 
228
- except Exception as e:
229
- print(f"[ERROR] Exception in analyze_image: {e}")
230
- import traceback
231
- print(f"[ERROR] Traceback: {traceback.format_exc()}")
232
- raise ValueError(f"Failed to analyze image: {e}")
233
-
234
- # ===== API Endpoints =====
235
- @app.get("/")
236
- async def root():
237
- """Health check endpoint"""
238
- return {
239
- "message": "Florence-2 Image Analysis API",
240
- "status": "running",
241
- "model_loaded": processor is not None and model is not None,
242
- "task": TASK
243
- }
244
 
245
- @app.get("/health")
246
- async def health_check():
247
- """Detailed health check"""
248
- return {
249
- "status": "healthy" if (processor and model) else "unhealthy",
250
- "model_loaded": processor is not None and model is not None,
251
- "device": DEVICE,
252
- "task": TASK
253
  }
254
 
255
- @app.post("/analyze", response_model=ImageAnalysisResponse)
256
- async def analyze_image_endpoint(request: ImageAnalysisRequest):
257
- """
258
- Analyze an image from a URL using Florence-2 model
259
- Always uses <MORE_DETAILED_CAPTION> task for detailed image descriptions
260
- """
261
- try:
262
- if not processor or not model:
263
- raise HTTPException(
264
- status_code=503,
265
- detail="Model not loaded. Please check server logs."
 
 
 
 
 
 
 
266
  )
267
-
268
- print(f"[INFO] Processing image from: {request.image_url}")
269
- image = download_image(request.image_url)
270
- print(f"[INFO] Image downloaded successfully: {image.size}")
271
-
272
- caption = analyze_image(image)
273
- print(f"[INFO] Analysis complete")
274
-
275
- return ImageAnalysisResponse(
276
- caption=caption,
277
- success=True
278
- )
279
-
280
- except HTTPException:
281
- raise
282
- except ValueError as e:
283
- print(f"[ERROR] ValueError: {e}")
284
- return ImageAnalysisResponse(
285
- caption="",
286
- success=False,
287
- error_message=str(e)
288
- )
289
- except Exception as e:
290
- print(f"[ERROR] Unexpected error: {e}")
291
- return ImageAnalysisResponse(
292
- caption="",
293
- success=False,
294
- error_message=f"Internal server error: {str(e)}"
295
- )
296
-
297
- @app.get("/analyze")
298
- async def analyze_image_get(image_url: str):
299
- """
300
- GET endpoint for quick image analysis
301
- Usage: /analyze?image_url=https://example.com/image.jpg
302
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  try:
304
- request = ImageAnalysisRequest(image_url=image_url)
305
- return await analyze_image_endpoint(request)
306
  except Exception as e:
307
- raise HTTPException(status_code=400, detail=str(e))
 
 
 
 
 
 
 
 
308
 
309
- # ===== Main Execution =====
310
  if __name__ == "__main__":
311
- if not processor or not model:
312
- print("[ERROR] Model failed to load. Cannot proceed with frame analysis.")
313
- sys.exit(1)
314
-
315
- print("[INFO] Starting frame analysis...")
316
- print(f"[INFO] Using URL template: {BASE_URL_TEMPLATE}")
317
- print(f"[INFO] Starting from frame {START_FRAME} with {FRAME_PADDING} digit padding")
318
-
319
- results = []
320
- for frame_num, url, result in iterate_and_analyze(
321
- BASE_URL_TEMPLATE,
322
- start=START_FRAME,
323
- padding=FRAME_PADDING
324
- ):
325
- if result["success"]:
326
- print(f"[SUCCESS] Frame {frame_num}: {result['caption']}")
327
- results.append({
328
- "frame": frame_num,
329
- "url": url,
330
- "caption": result["caption"]
331
- })
332
- else:
333
- print(f"[ERROR] Frame {frame_num}: {result['error']}")
334
- results.append({
335
- "frame": frame_num,
336
- "url": url,
337
- "error": result["error"]
338
- })
339
-
340
- # Save results to a JSON file
341
- import json
342
- output_file = "frame_analysis_results.json"
343
- with open(output_file, "w", encoding="utf-8") as f:
344
- json.dump(results, f, indent=2, ensure_ascii=False)
345
- print(f"[INFO] Results saved to {output_file}")
346
-
347
- # Optional: start the API server after frame analysis
348
- start_server = os.getenv("START_SERVER", "false").lower() == "true"
349
- if start_server:
350
- port = int(os.getenv("PORT", 7860))
351
- print(f"[INFO] Starting server on port {port}")
352
- print(f"[INFO] Task: {TASK}")
353
- print(f"[INFO] API Documentation: http://localhost:{port}/docs")
354
-
355
- uvicorn.run(
356
- app,
357
- host="0.0.0.0",
358
- port=port,
359
- reload=False
360
- )
 
1
import json
import os
import re
import time
from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs, urlparse

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  try:
12
+ from huggingface_hub import HfApi
13
+ HF_AVAILABLE = True
14
+ except Exception:
15
+ HfApi = None
16
+ HF_AVAILABLE = False
17
+
18
+ # Directory to store compiled uploads
19
+ BASE_DIR = os.path.dirname(__file__)
20
+ UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
21
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
22
+
23
+ app = FastAPI(title="Data Collection Server", description="Receives text/URLs from captioning/image servers, groups by course, compiles JSON and optionally uploads to HuggingFace.")
24
+
25
+ # In-memory store for course data
26
+ courses: Dict[str, Dict[str, Any]] = {}
27
+
28
+ URL_RE = re.compile(r"https?://[\w\-\./?%&=:@,+~#]+")
29
+ DONE_RE = re.compile(r"\b(done|finished|completed|complete)\b", re.IGNORECASE)
30
+
31
+ HF_TOKEN = os.getenv("HF_TOKEN")
32
+ HF_DATASET_REPO = os.getenv("HF_DATASET_REPO") # e.g. "username/dataset-name"
33
+
34
+
35
+ def extract_urls(text: str) -> List[str]:
36
+ return URL_RE.findall(text or "")
37
+
38
+
39
+ def extract_course_from_url(url: str) -> str:
40
  try:
41
+ parsed = urlparse(url)
42
+ qs = parse_qs(parsed.query)
43
+ course = qs.get("course") or qs.get("Course") or qs.get("COURSE")
44
+ if course:
45
+ return course[0]
46
+ except Exception:
47
+ pass
48
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
 
51
+ def now_ts() -> str:
52
+ return time.strftime("%Y%m%dT%H%M%S")
53
 
 
 
 
54
 
55
+ async def parse_request(request: Request) -> Dict[str, Any]:
56
+ """Read incoming request in any format and return a dict with keys: text, json, form, headers"""
57
+ payload = {"text": "", "json": None, "form": {}, "headers": dict(request.headers)}
58
+
59
+ # Try JSON
60
+ try:
61
+ body = await request.json()
62
+ payload["json"] = body
63
+ # if it's a simple string payload inside JSON
64
+ if isinstance(body, str):
65
+ payload["text"] = body
66
+ elif isinstance(body, dict):
67
+ # flatten likely fields
68
+ for k in ["text", "caption", "message", "body", "content"]:
69
+ if k in body and isinstance(body[k], str):
70
+ payload["text"] = body[k]
71
+ break
72
+ # allow explicit course field
73
+ if "course" in body and isinstance(body["course"], str):
74
+ payload["course"] = body["course"]
75
+ except Exception:
76
+ # not JSON - try raw body
77
  try:
78
+ raw = (await request.body()).decode("utf-8", errors="ignore")
79
+ payload["text"] = raw
80
+ except Exception:
81
+ payload["text"] = ""
82
+
83
+ # Try form (for multipart/form-data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  try:
85
+ form = await request.form()
86
+ for k, v in form.multi_items():
87
+ # take first text-like value
88
+ payload["form"][k] = str(v)
89
+ if k in ("text", "caption", "message", "content") and not payload["text"]:
90
+ payload["text"] = str(v)
91
+ if k == "course":
92
+ payload["course"] = str(v)
93
+ except Exception:
94
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ # If no text yet but JSON is a list or similar, stringify (best-effort)
97
+ if not payload["text"] and payload.get("json") is not None:
98
+ try:
99
+ payload["text"] = json.dumps(payload["json"])
100
+ except Exception:
101
+ payload["text"] = str(payload["json"])
 
 
 
 
 
 
102
 
103
+ return payload
104
+
105
+
106
+ def add_entry(course: str, entry: Dict[str, Any]):
107
+ c = courses.setdefault(course, {"items": [], "last_updated": None})
108
+ c["items"].append(entry)
109
+ c["last_updated"] = time.time()
110
+
111
+
112
+ def compile_course(course: str) -> str:
113
+ """Compile course data to JSON file and optionally upload to HuggingFace. Returns path to saved file."""
114
+ if course not in courses:
115
+ raise ValueError(f"Unknown course: {course}")
 
 
 
116
 
117
+ data = {
118
+ "course": course,
119
+ "compiled_at": now_ts(),
120
+ "count": len(courses[course]["items"]),
121
+ "items": courses[course]["items"],
 
 
 
122
  }
123
 
124
+ filename = f"{course}_{now_ts()}.json"
125
+ safe_filename = re.sub(r"[^a-zA-Z0-9_\-\.]+", "_", filename)
126
+ path = os.path.join(UPLOAD_DIR, safe_filename)
127
+
128
+ with open(path, "w", encoding="utf-8") as f:
129
+ json.dump(data, f, ensure_ascii=False, indent=2)
130
+
131
+ # Optionally upload to HuggingFace
132
+ if HF_TOKEN and HF_DATASET_REPO and HF_AVAILABLE:
133
+ try:
134
+ api = HfApi()
135
+ # upload path at root of repo with same filename
136
+ api.upload_file(
137
+ path_or_fileobj=path,
138
+ path_in_repo=safe_filename,
139
+ repo_id=HF_DATASET_REPO,
140
+ repo_type="dataset",
141
+ token=HF_TOKEN,
142
  )
143
+ except Exception as e:
144
+ # Log but don't fail the compile
145
+ print(f"[WARN] HuggingFace upload failed: {e}")
146
+
147
+ # After compiling, clear stored items for that course
148
+ courses[course]["items"] = []
149
+ return path
150
+
151
+
152
+ @app.post("/submit")
153
+ async def submit(request: Request):
154
+ """Receive any data (text, JSON, form). Will try to extract course and URLs and store entries.
155
+ If message contains 'done' or similar, it will compile the course to JSON (and upload if configured).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  """
157
+ payload = await parse_request(request)
158
+ text = (payload.get("text") or "").strip()
159
+
160
+ # Collect urls found
161
+ urls = extract_urls(text)
162
+
163
+ # Determine course from payload (explicit field) or from any URL
164
+ course = payload.get("course")
165
+ if not course:
166
+ for u in urls:
167
+ c = extract_course_from_url(u)
168
+ if c:
169
+ course = c
170
+ break
171
+
172
+ if not course:
173
+ course = "unknown_course"
174
+
175
+ entry = {
176
+ "timestamp": now_ts(),
177
+ "text": text,
178
+ "json": payload.get("json"),
179
+ "form": payload.get("form"),
180
+ "urls": urls,
181
+ "headers": {k: v for k, v in payload.get("headers", {}).items() if k.lower() in ("user-agent", "host", "content-type")},
182
+ }
183
+
184
+ add_entry(course, entry)
185
+
186
+ # Detect completion
187
+ if DONE_RE.search(text):
188
+ try:
189
+ path = compile_course(course)
190
+ return JSONResponse({"status": "compiled", "course": course, "path": path})
191
+ except Exception as e:
192
+ raise HTTPException(status_code=500, detail=str(e))
193
+
194
+ # Detect explicit 'course change' in URLs (if a URL contains a different course than stored) -- best-effort
195
+ # If a URL indicates a different course and there were previous items, compile previous course first
196
+ # Example: previous stored course is same; we don't track per-source last course, so skip this more complex behavior for now
197
+
198
+ return JSONResponse({"status": "stored", "course": course, "count": len(courses[course]["items"])})
199
+
200
+
201
+ @app.get("/status")
202
+ async def status():
203
+ summary = {c: {"count": len(v["items"]), "last_updated": v["last_updated"]} for c, v in courses.items()}
204
+ return {"courses": summary}
205
+
206
+
207
+ @app.post("/compile")
208
+ async def compile_endpoint(course: str = None):
209
+ """Force compile a course. If course is not provided and only one exists, compile that one."""
210
+ if not course:
211
+ if len(courses) == 1:
212
+ course = next(iter(courses.keys()))
213
+ else:
214
+ raise HTTPException(status_code=400, detail="Provide course query parameter when multiple courses exist.")
215
+
216
  try:
217
+ path = compile_course(course)
218
+ return {"status": "compiled", "course": course, "path": path}
219
  except Exception as e:
220
+ raise HTTPException(status_code=500, detail=str(e))
221
+
222
+
223
+ @app.get("/debug/{course}")
224
+ async def debug_course(course: str):
225
+ if course not in courses:
226
+ raise HTTPException(status_code=404, detail="Course not found")
227
+ return courses[course]
228
+
229
 
 
230
  if __name__ == "__main__":
231
+ import uvicorn
232
+ port = int(os.getenv("PORT", "8000"))
233
+ uvicorn.run(app, host="0.0.0.0", port=port)