Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
-
import
|
|
|
|
|
|
|
| 3 |
import requests
|
| 4 |
from PIL import Image
|
| 5 |
from io import BytesIO
|
|
@@ -8,12 +10,59 @@ from pydantic import BaseModel, HttpUrl
|
|
| 8 |
from transformers import AutoProcessor, AutoModelForCausalLM
|
| 9 |
import uvicorn
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
# ===== CONFIG =====
|
| 12 |
DEVICE = "cpu" # Use CPU for compatibility
|
| 13 |
RESIZE_DIM = (512, 512) # Resize images to this resolution
|
| 14 |
MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB max image size
|
| 15 |
TASK = "<MORE_DETAILED_CAPTION>" # Hardcoded task
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
# ===== FastAPI App =====
|
| 18 |
app = FastAPI(
|
| 19 |
title="Florence-2 Image Analysis API",
|
|
@@ -81,6 +130,48 @@ def download_image(url: str) -> Image.Image:
|
|
| 81 |
except Exception as e:
|
| 82 |
raise ValueError(f"Failed to process image: {e}")
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
def analyze_image(image: Image.Image) -> str:
|
| 85 |
"""Analyze image using Florence-2 model with hardcoded task"""
|
| 86 |
if not processor or not model:
|
|
@@ -217,15 +308,53 @@ async def analyze_image_get(image_url: str):
|
|
| 217 |
|
| 218 |
# ===== Main Execution =====
|
| 219 |
if __name__ == "__main__":
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
print(
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
+
import sys
|
| 3 |
+
import subprocess
|
| 4 |
+
import importlib
|
| 5 |
import requests
|
| 6 |
from PIL import Image
|
| 7 |
from io import BytesIO
|
|
|
|
| 10 |
from transformers import AutoProcessor, AutoModelForCausalLM
|
| 11 |
import uvicorn
|
| 12 |
|
| 13 |
+
# ===== RUNTIME DEPENDENCY ENSURER =====
# Hardcoded torch version to ensure compatibility at startup.
REQUIRED_TORCH_VERSION = os.getenv("REQUIRED_TORCH_VERSION", "2.2.2")


def ensure_torch_installed(required_version: str = REQUIRED_TORCH_VERSION):
    """Ensure the required torch version is installed at runtime.

    Imports torch and compares its version against ``required_version``. If
    torch is missing or the version differs, the requested version is
    pip-installed with the running Python executable and imported afterwards.

    Args:
        required_version: Exact ``major.minor.patch`` torch version required.

    Returns:
        The imported ``torch`` module.

    Raises:
        subprocess.CalledProcessError: If the pip install command fails.

    Note: Installing torch at every start may be slow and may require build
    artifacts specific to the platform. This helper uses a simple pip install;
    if your target platform requires a special wheel or extra index URL, set up
    the environment outside of this script or modify the install command.
    """
    try:
        import torch as _t
        v = getattr(_t, "__version__", "")
        # BUG FIX: the original prefix check ``v.startswith(required_version)``
        # wrongly accepted e.g. "2.2.20" for a required "2.2.2" and wrongly
        # rejected local-label builds such as "2.2.2+cpu". Strip any PEP 440
        # local version label ("+cpu", "+cu121") and compare exactly.
        if v and v.split("+", 1)[0] == required_version:
            print(f"[INFO] torch {v} already installed")
            return _t
        else:
            print(f"[INFO] torch version {v} != {required_version}, will reinstall")
    except ImportError:
        # Only a missing module should trigger the "not found" path; any other
        # failure mode will surface from the reinstall attempt below.
        print("[INFO] torch not found, installing now")

    cmd = [sys.executable, "-m", "pip", "install", f"torch=={required_version}"]
    print(f"[INFO] Running: {' '.join(cmd)}")
    try:
        subprocess.check_call(cmd)
        # Make the freshly installed package visible to the import system.
        importlib.invalidate_caches()
        import torch as _t2
        print(f"[INFO] Installed torch {_t2.__version__}")
        return _t2
    except subprocess.CalledProcessError as e:
        print(f"[ERROR] pip install failed: {e}")
        raise
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# Ensure torch is available before using the model.
# NOTE(review): this runs at import time and may invoke ``pip install`` on
# every cold start — slow, and it needs network access; consider pinning
# torch in the deployment image instead and keeping this as a fallback.
torch = ensure_torch_installed()
|
| 54 |
+
|
| 55 |
# ===== CONFIG =====
DEVICE = "cpu"  # Use CPU for compatibility
RESIZE_DIM = (512, 512)  # Resize images to this resolution
MAX_IMAGE_SIZE = 10 * 1024 * 1024  # 10MB max image size
TASK = "<MORE_DETAILED_CAPTION>"  # Hardcoded task

# URL template for frame iteration - replace with your actual URL.
# The zero-padded frame number is substituted for the "{frame}" placeholder.
BASE_URL_TEMPLATE = "https://example.com/frames/frame_{frame}.jpg"
START_FRAME = 1  # Starting frame number
FRAME_PADDING = 6  # Number of digits to pad frame numbers with
|
| 65 |
+
|
| 66 |
# ===== FastAPI App =====
|
| 67 |
app = FastAPI(
|
| 68 |
title="Florence-2 Image Analysis API",
|
|
|
|
| 130 |
except Exception as e:
|
| 131 |
raise ValueError(f"Failed to process image: {e}")
|
| 132 |
|
| 133 |
+
|
| 134 |
+
def iterate_and_analyze(base_url_template: str, start: int = 1, padding: int = 6,
                        *, downloader=None, analyzer=None):
    """Iterate over a templated frame URL and analyze images sequentially.

    ``base_url_template`` must contain a ``{frame}`` placeholder, which is
    replaced by the zero-padded frame number, for example:
        https://example.com/frames/frame_{frame}.jpg

    Args:
        base_url_template: URL template containing "{frame}".
        start: First frame number to fetch.
        padding: Number of digits to zero-pad frame numbers to.
        downloader: Optional override for the download step (defaults to this
            module's ``download_image``); mainly for testing.
        analyzer: Optional override for the captioning step (defaults to this
            module's ``analyze_image``); mainly for testing.

    Yields:
        ``(frame_number, url, result)`` tuples where ``result`` is
        ``{"success": True, "caption": ...}`` on success or
        ``{"success": False, "error": ...}`` on failure.

    Stops at the first 404 (no more frames) or after three consecutive
    errors of any kind.

    Raises:
        ValueError: If the template lacks the "{frame}" placeholder.
    """
    if "{frame}" not in base_url_template:
        raise ValueError("base_url_template must contain '{frame}' placeholder")

    # Resolve defaults after the template check so tests can inject fakes
    # without touching the module-level helpers.
    if downloader is None:
        downloader = download_image
    if analyzer is None:
        analyzer = analyze_image

    consecutive_errors = 0
    MAX_CONSECUTIVE_ERRORS = 3  # Stop after this many consecutive errors

    i = start
    while True:
        frame_id = str(i).zfill(padding)
        url = base_url_template.format(frame=frame_id)
        try:
            img = downloader(url)
            caption = analyzer(img)
            consecutive_errors = 0  # Reset error counter on success
            yield (i, url, {"success": True, "caption": caption})
        except Exception as e:
            # BUG FIX: ``download_image`` wraps every failure (including HTTP
            # 404s) in a ValueError, so the original
            # ``except requests.exceptions.HTTPError`` branch was dead code and
            # the "end of frames" stop never fired. Detect a 404 either from a
            # real HTTPError's response object or (heuristically) from the
            # wrapped error message.
            status = getattr(getattr(e, "response", None), "status_code", None)
            if status == 404 or "404" in str(e):
                print(f"[INFO] No more frames found after frame {i-1}")
                break
            yield (i, url, {"success": False, "error": str(e)})
            consecutive_errors += 1
            if consecutive_errors >= MAX_CONSECUTIVE_ERRORS:
                print(f"[INFO] Stopping after {MAX_CONSECUTIVE_ERRORS} consecutive errors")
                break

        i += 1
|
| 174 |
+
|
| 175 |
def analyze_image(image: Image.Image) -> str:
|
| 176 |
"""Analyze image using Florence-2 model with hardcoded task"""
|
| 177 |
if not processor or not model:
|
|
|
|
| 308 |
|
| 309 |
# ===== Main Execution =====
|
| 310 |
if __name__ == "__main__":
    # Abort early if the model/processor globals (loaded elsewhere in this
    # module) failed to initialize — nothing below can work without them.
    if not processor or not model:
        print("[ERROR] Model failed to load. Cannot proceed with frame analysis.")
        sys.exit(1)

    print("[INFO] Starting frame analysis...")
    print(f"[INFO] Using URL template: {BASE_URL_TEMPLATE}")
    print(f"[INFO] Starting from frame {START_FRAME} with {FRAME_PADDING} digit padding")

    # Walk frames sequentially, collecting one record per frame: either a
    # caption (success) or the error string (failure).
    results = []
    for frame_num, url, result in iterate_and_analyze(
        BASE_URL_TEMPLATE,
        start=START_FRAME,
        padding=FRAME_PADDING
    ):
        if result["success"]:
            print(f"[SUCCESS] Frame {frame_num}: {result['caption']}")
            results.append({
                "frame": frame_num,
                "url": url,
                "caption": result["caption"]
            })
        else:
            print(f"[ERROR] Frame {frame_num}: {result['error']}")
            results.append({
                "frame": frame_num,
                "url": url,
                "error": result["error"]
            })

    # Save results to a JSON file
    import json
    output_file = "frame_analysis_results.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"[INFO] Results saved to {output_file}")

    # Optional: start the API server after frame analysis
    # (opt-in via START_SERVER=true; PORT defaults to 7860).
    start_server = os.getenv("START_SERVER", "false").lower() == "true"
    if start_server:
        port = int(os.getenv("PORT", 7860))
        print(f"[INFO] Starting server on port {port}")
        print(f"[INFO] Task: {TASK}")
        print(f"[INFO] API Documentation: http://localhost:{port}/docs")

        uvicorn.run(
            app,
            host="0.0.0.0",
            port=port,
            reload=False
        )
|