Fred808 committed on
Commit
d059378
·
verified ·
1 Parent(s): 03901aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +195 -81
app.py CHANGED
@@ -1,100 +1,214 @@
1
  import os
2
- import cv2
3
  import torch
4
- from pathlib import Path
5
  from PIL import Image
 
 
 
6
  from transformers import AutoProcessor, AutoModelForCausalLM
 
7
 
8
  # ===== CONFIG =====
9
# Pipeline settings read by the frame-extraction and captioning code below.
VIDEO_PATH = "How.mp4"  # Local video file in root
FRAMES_DIR = "extracted"  # Where frames are stored
FPS = 3  # Frames to extract per second
DEVICE = "cpu"  # Use CPU for compatibility
RESIZE_DIM = (512, 512)  # Resize images to this resolution
 
14
 
15
- # ===== Ensure Output Directory =====
16
def ensure_dir(path):
    """Create directory *path*, including missing parents; a no-op if it exists."""
    Path(path).mkdir(parents=True, exist_ok=True)
 
 
 
18
 
19
- # ===== Frame Extraction Function =====
20
def extract_frames(video_path, output_dir, fps=3):
    """Save every N-th frame of a video as a numbered PNG.

    N is derived from the video's reported frame rate so that roughly
    *fps* frames are kept per second of footage.

    Args:
        video_path: Video file handed to ``cv2.VideoCapture``.
        output_dir: Directory that receives the PNG files; created if missing.
        fps: Target number of saved frames per second. Must be positive.

    Returns:
        The written frame paths as strings, in capture order. An empty
        list when the video cannot be opened.

    Raises:
        ValueError: If *fps* is zero or negative.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps!r}")

    ensure_dir(output_dir)
    cap = cv2.VideoCapture(str(video_path))
    frame_paths = []
    try:
        if not cap.isOpened():
            print(f"[ERROR] Failed to open video file: {video_path}")
            return []

        video_fps = cap.get(cv2.CAP_PROP_FPS)
        if not video_fps or video_fps <= 0:
            print("[WARN] Using fallback FPS: 30")
            video_fps = 30

        # Clamp to 1: when more frames per second are requested than the
        # video provides, the rounded ratio is 0 and `% 0` would raise.
        frame_interval = max(1, int(round(video_fps / fps)))

        frame_idx = 0
        saved_idx = 1
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx % frame_interval == 0:
                frame_name = f"{saved_idx:04d}.png"
                output_path = Path(output_dir) / frame_name
                cv2.imwrite(str(output_path), frame)
                frame_paths.append(str(output_path))
                print(f"[INFO] Saved frame {frame_idx} -> {frame_name}")
                saved_idx += 1
            frame_idx += 1
    finally:
        # Release the capture handle even if reading or writing fails.
        cap.release()
    return frame_paths
52
 
53
  # ===== Load Florence-2 Base Model =====
54
# Runs at import time: fetches the Florence-2 processor and weights, moves
# the model to DEVICE and switches it to eval mode.
print("[INFO] Loading Florence-2-base model on CPU...")
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base",
    trust_remote_code=True,
    attn_implementation="eager",
).to(DEVICE).eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- # ===== Analyze a Frame =====
63
def analyze_frame(image_path):
    """Caption one image file with the module-level Florence-2 model.

    Args:
        image_path: Path of the image to open with PIL.

    Returns:
        The value stored under ``"<MORE_DETAILED_CAPTION>"`` in the
        processor's post-processed output.
    """
    task = "<MORE_DETAILED_CAPTION>"

    # Close the file handle as soon as the pixels are copied; the converted
    # and resized image no longer depends on the open file.
    with Image.open(image_path) as raw:
        image = raw.convert("RGB").resize(RESIZE_DIM, Image.BILINEAR)  # Resize for speed

    inputs = processor(
        text=task,
        images=image,
        return_tensors="pt",
    ).to(DEVICE)

    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
            do_sample=False,
        )

    # Special tokens are kept because post_process_generation parses them.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    result = processor.post_process_generation(
        generated_text,
        task=task,
        image_size=RESIZE_DIM,
    )
    return result[task]
86
 
87
- # ===== Main Execution =====
88
if __name__ == "__main__":
    # Extract frames from the configured video, then caption each one.
    frame_list = extract_frames(VIDEO_PATH, FRAMES_DIR, FPS)
    print(f"[INFO] Extracted {len(frame_list)} frames.")

    for idx, frame_path in enumerate(frame_list, start=1):
        print(f"\n[FRAME {idx}] Analyzing: {frame_path}")
        caption = analyze_frame(frame_path)
        print(f"[RESULT] {caption}")

    # Optional: Start a dummy Uvicorn server (if you want to expand into an API later)
    if os.getenv("RUN_SERVER"):
        # Imported only when needed so the batch run works without uvicorn.
        import uvicorn

        port = int(os.getenv("PORT", 7860))  # for Gradio Spaces compatibility
        # NOTE(review): no `app` object is defined in this version of the
        # file, so "main:app" presumably points at a separate module; confirm.
        uvicorn.run("main:app", host="0.0.0.0", port=port)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import os
from io import BytesIO
from typing import Optional

import requests
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel, HttpUrl
from transformers import AutoProcessor, AutoModelForCausalLM
10
 
11
  # ===== CONFIG =====
 
 
 
12
DEVICE = "cpu"  # Use CPU for compatibility
RESIZE_DIM = (512, 512)  # Resize images to this resolution
MAX_IMAGE_SIZE = 10 * 1024 * 1024  # 10MB max image size
15
 
16
+ # ===== FastAPI App =====
17
# ASGI application object; the route decorators below register on it.
app = FastAPI(
    title="Florence-2 Image Analysis API",
    description="Analyze images using Microsoft's Florence-2 model",
    version="1.0.0",
)
22
 
23
+ # ===== Request/Response Models =====
24
class ImageAnalysisRequest(BaseModel):
    """Request body for the POST /analyze endpoint."""

    # URL of the image to download and analyze.
    image_url: HttpUrl
    # Florence-2 task prompt token.
    task: str = "<MORE_DETAILED_CAPTION>"  # Default task
 
 
 
27
 
28
class ImageAnalysisResponse(BaseModel):
    """Response body returned by the /analyze endpoints."""

    # Model output; an empty string when the analysis failed.
    caption: str
    # True when the image was downloaded and analyzed without error.
    success: bool
    # Failure description; None on success. Declared Optional so the None
    # default matches the annotation.
    error_message: Optional[str] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  # ===== Load Florence-2 Base Model =====
34
# Runs at import time. A load failure is logged rather than raised so the
# API can still start; processor/model are then None and the health
# endpoints report the model as not loaded.
print("[INFO] Loading Florence-2-base model on CPU...")
try:
    processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Florence-2-base",
        trust_remote_code=True,
        attn_implementation="eager",
    ).to(DEVICE).eval()
    print("[INFO] Model loaded successfully!")
except Exception as e:
    print(f"[ERROR] Failed to load model: {e}")
    processor = None
    model = None
47
+
48
+ # ===== Helper Functions =====
49
def download_image(url: str) -> Image.Image:
    """Download an image over HTTP and return it as an RGB PIL image.

    The body is streamed and the transfer is abandoned once it exceeds
    MAX_IMAGE_SIZE, so an oversized resource is never held fully in memory.

    Args:
        url: Address of the image; converted with ``str()`` before use.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        ValueError: If the request fails, the Content-Type is not
            ``image/*``, the body exceeds MAX_IMAGE_SIZE, or the bytes
            cannot be decoded as an image.
    """
    # Set headers to mimic browser request
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }

    try:
        with requests.get(str(url), headers=headers, timeout=30, stream=True) as response:
            response.raise_for_status()

            # Check if content is actually an image before reading the body.
            content_type = response.headers.get('content-type', '')
            if not content_type.startswith('image/'):
                raise ValueError(f"URL does not point to an image. Content-Type: {content_type}")

            # Reject early when the server declares an oversized body.
            declared = response.headers.get('content-length', '')
            if declared.isdigit() and int(declared) > MAX_IMAGE_SIZE:
                raise ValueError(f"Image too large: {declared} bytes (max: {MAX_IMAGE_SIZE})")

            # Content-Length may be absent or wrong, so enforce the cap
            # while reading as well.
            payload = bytearray()
            for chunk in response.iter_content(chunk_size=64 * 1024):
                payload.extend(chunk)
                if len(payload) > MAX_IMAGE_SIZE:
                    raise ValueError(f"Image too large: more than {MAX_IMAGE_SIZE} bytes (max: {MAX_IMAGE_SIZE})")
    except requests.exceptions.RequestException as e:
        raise ValueError(f"Failed to download image: {e}") from e

    # Decoding is wrapped separately so the validation errors above are not
    # re-labelled as processing failures.
    try:
        return Image.open(BytesIO(bytes(payload))).convert("RGB")
    except Exception as e:
        raise ValueError(f"Failed to process image: {e}") from e
76
 
77
def analyze_image(image: Image.Image, task: str = "<MORE_DETAILED_CAPTION>") -> str:
    """Run a Florence-2 task on an image and return the result as text.

    Args:
        image: Image to analyze; it is resized to RESIZE_DIM first.
        task: Florence-2 task prompt token, used both as the model prompt
            and as the key into the post-processed result.

    Returns:
        The post-processed result for *task*. A string result is returned
        unchanged. Any other value is serialized to JSON so the declared
        ``str`` return type holds (presumably the dict of boxes and labels
        produced by tasks such as ``<OD>``; confirm against the model card).

    Raises:
        ValueError: If the model or processor failed to load, or if any
            step of the analysis raises.
    """
    # Identity check: these are None only when loading failed; truthiness
    # of the loaded objects is not a reliable "loaded" signal.
    if processor is None or model is None:
        raise ValueError("Model not loaded properly")

    try:
        # Resize image for faster processing
        image = image.resize(RESIZE_DIM, Image.BILINEAR)

        # Prepare inputs
        inputs = processor(
            text=task,
            images=image,
            return_tensors="pt",
        ).to(DEVICE)

        # Generate caption
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                num_beams=3,
                do_sample=False,
            )

        # Decode and post-process; special tokens are kept because
        # post_process_generation parses them.
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        result = processor.post_process_generation(
            generated_text,
            task=task,
            image_size=RESIZE_DIM,
        )

        value = result[task]
        return value if isinstance(value, str) else json.dumps(value)

    except Exception as e:
        raise ValueError(f"Failed to analyze image: {e}") from e
 
 
115
 
116
+ # ===== API Endpoints =====
117
@app.get("/")
async def root():
    """Health check endpoint: report that the API is up and whether the model loaded."""
    return {
        "message": "Florence-2 Image Analysis API",
        "status": "running",
        "model_loaded": processor is not None and model is not None,
    }
125
 
126
@app.get("/health")
async def health_check():
    """Detailed health check: model status, device and supported task tokens."""
    # Evaluated once with an identity check so "status" and "model_loaded"
    # can never disagree.
    model_loaded = processor is not None and model is not None
    return {
        "status": "healthy" if model_loaded else "unhealthy",
        "model_loaded": model_loaded,
        "device": DEVICE,
        "available_tasks": [
            "<MORE_DETAILED_CAPTION>",
            "<DETAILED_CAPTION>",
            "<CAPTION>",
            "<OD>",  # Object Detection
            "<DENSE_REGION_CAPTION>",
            "<REGION_PROPOSAL>",
        ],
    }
142
 
143
@app.post("/analyze", response_model=ImageAnalysisResponse)
async def analyze_image_endpoint(request: ImageAnalysisRequest):
    """
    Analyze an image from a URL using Florence-2 model

    Available tasks:
    - <MORE_DETAILED_CAPTION>: Generate detailed image captions
    - <DETAILED_CAPTION>: Generate detailed captions
    - <CAPTION>: Generate basic captions
    - <OD>: Object detection
    - <DENSE_REGION_CAPTION>: Dense region captioning
    - <REGION_PROPOSAL>: Region proposal

    Responds with HTTP 400 for an unknown task, a ``success=False`` body
    when download or analysis raises ValueError, and HTTP 500 for any
    other error.
    """
    # Validate task. This sits outside the try block so the 400 below is
    # not caught by the generic handler and re-raised as a 500.
    valid_tasks = [
        "<MORE_DETAILED_CAPTION>", "<DETAILED_CAPTION>", "<CAPTION>",
        "<OD>", "<DENSE_REGION_CAPTION>", "<REGION_PROPOSAL>"
    ]
    if request.task not in valid_tasks:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid task. Available tasks: {valid_tasks}"
        )

    try:
        # Download and process image
        print(f"[INFO] Processing image from: {request.image_url}")
        image = download_image(request.image_url)
        print(f"[INFO] Image downloaded successfully: {image.size}")

        # Analyze image
        caption = analyze_image(image, request.task)
        print(f"[INFO] Analysis complete: {caption}")

        return ImageAnalysisResponse(
            caption=caption,
            success=True
        )

    except ValueError as e:
        # Expected failures (bad URL content, model errors) are reported in
        # the response body rather than as an HTTP error.
        print(f"[ERROR] ValueError: {e}")
        return ImageAnalysisResponse(
            caption="",
            success=False,
            error_message=str(e)
        )
    except Exception as e:
        print(f"[ERROR] Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}")
192
+
193
@app.get("/analyze")
async def analyze_image_get(image_url: str, task: str = "<MORE_DETAILED_CAPTION>"):
    """
    GET endpoint for quick image analysis
    Usage: /analyze?image_url=https://example.com/image.jpg&task=<MORE_DETAILED_CAPTION>

    Responds with HTTP 422 when image_url is not a valid URL; otherwise
    behaves like the POST endpoint.
    """
    try:
        request = ImageAnalysisRequest(image_url=image_url, task=task)
    except ValueError as e:
        # pydantic's ValidationError subclasses ValueError; without this a
        # malformed URL would surface as an unhandled 500.
        raise HTTPException(status_code=422, detail=str(e))
    return await analyze_image_endpoint(request)
201
+
202
+ # ===== Main Execution =====
203
if __name__ == "__main__":
    # Start the API server on the port given by the PORT environment
    # variable, falling back to 7860.
    port = int(os.getenv("PORT", 7860))
    print(f"[INFO] Starting server on port {port}")
    print(f"[INFO] Model status: {'Loaded' if (processor is not None and model is not None) else 'Failed to load'}")
    print(f"[INFO] API Documentation: http://localhost:{port}/docs")

    # Pass the app object itself: this file is app.py, so the previous
    # import string "main:app" would make uvicorn look for a module named
    # "main". For development with reload=True uvicorn needs an import
    # string instead, i.e. "app:app".
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        reload=False,
    )