azizerorahman committed on
Commit
59cb1b6
·
verified ·
1 Parent(s): 8540f8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +256 -164
app.py CHANGED
@@ -1,12 +1,47 @@
1
- import gradio as gr
 
 
 
 
 
 
 
 
2
  import cv2
3
  import torch
4
  import numpy as np
5
  import tempfile
6
  import os
 
 
 
7
  from siamrpn import TrackerSiamRPN
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # Global tracker instance
 
 
 
 
 
 
 
 
 
 
10
  tracker = None
11
  device = None
12
 
@@ -14,37 +49,32 @@ def load_tracker():
14
  """Load the SiamRPN tracker with GPU support"""
15
  global tracker, device
16
  if tracker is None:
 
 
 
17
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
- print(f"Loading tracker on {device}...")
19
- tracker = TrackerSiamRPN(net_path='model.pth')
20
- print("✓ Tracker loaded successfully")
21
  return tracker
22
 
23
- def track_video(video_file, bbox_x, bbox_y, bbox_w, bbox_h):
 
24
  """
25
- Track an object in a video using SiamRPN
 
 
 
 
 
 
 
26
  """
27
  try:
28
- # Validate inputs
29
- if video_file is None:
30
- return None, "❌ Please upload a video file"
31
-
32
- bbox_x = int(bbox_x)
33
- bbox_y = int(bbox_y)
34
- bbox_w = int(bbox_w)
35
- bbox_h = int(bbox_h)
36
-
37
- if bbox_w <= 0 or bbox_h <= 0:
38
- return None, "❌ Bounding box width and height must be positive"
39
-
40
- # Load tracker
41
  tracker_instance = load_tracker()
42
- gpu_status = "🚀 GPU (CUDA)" if device.type == 'cuda' else "💻 CPU"
43
 
44
- # Open video
45
- cap = cv2.VideoCapture(video_file)
46
  if not cap.isOpened():
47
- return None, " Failed to open video file"
48
 
49
  # Get video properties
50
  fps = int(cap.get(cv2.CAP_PROP_FPS))
@@ -54,183 +84,245 @@ def track_video(video_file, bbox_x, bbox_y, bbox_w, bbox_h):
54
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
55
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
56
 
57
- print(f"Video: {width}x{height} @ {fps}fps, {total_frames} frames")
58
- print(f"Bounding box: ({bbox_x}, {bbox_y}, {bbox_w}, {bbox_h})")
59
 
60
- # Read first frame
61
- ret, first_frame = cap.read()
62
  if not ret:
63
- cap.release()
64
- return None, "❌ Failed to read first frame"
65
 
66
  # Validate bounding box
67
- if bbox_x < 0 or bbox_y < 0 or bbox_x + bbox_w > width or bbox_y + bbox_h > height:
68
- cap.release()
69
- return None, f"❌ Bounding box outside frame (frame: {width}x{height})"
 
 
 
 
 
70
 
71
  # Initialize tracker
72
- init_bbox = [bbox_x, bbox_y, bbox_w, bbox_h]
73
- tracker_instance.init(first_frame, init_bbox)
74
 
75
- # Create output file
76
- temp_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
77
  temp_output.close()
78
- output_path = temp_output.name
79
 
80
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
81
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
 
82
 
83
- if not out.isOpened():
84
- cap.release()
85
- return None, "❌ Failed to create output video"
86
 
87
- # Draw on first frame
88
- first_frame_copy = first_frame.copy()
89
- cv2.rectangle(first_frame_copy, (bbox_x, bbox_y), (bbox_x + bbox_w, bbox_y + bbox_h), (0, 255, 0), 3)
90
- cv2.putText(first_frame_copy, 'Initial Target', (bbox_x, bbox_y - 10),
91
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
92
- out.write(first_frame_copy)
93
 
94
- # Process frames
95
  frame_count = 1
 
96
  while True:
97
  ret, frame = cap.read()
98
  if not ret:
99
  break
100
 
 
 
 
101
  bbox = tracker_instance.update(frame)
102
- x, y, w, h = [int(v) for v in bbox]
103
 
104
- # Clamp to frame
 
105
  x = max(0, min(x, width - 1))
106
  y = max(0, min(y, height - 1))
107
  w = max(1, min(w, width - x))
108
  h = max(1, min(h, height - y))
109
 
110
- cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 3)
111
- cv2.putText(frame, f'Frame {frame_count}', (10, 30),
112
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
113
 
114
- out.write(frame)
115
- frame_count += 1
 
 
116
 
117
  cap.release()
118
- out.release()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
- message = f"""
121
- ✅ **Tracking Complete!**
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
- - Processed: {frame_count} frames
124
- - Device: {gpu_status}
125
- - Resolution: {width}x{height}
126
- - FPS: {fps}
127
 
128
- Download your tracked video below! 🎥
129
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- return output_path, message
 
 
 
 
 
 
 
 
 
 
132
 
 
 
133
  except Exception as e:
134
- print(f"Error: {str(e)}")
135
- import traceback
136
- traceback.print_exc()
137
- return None, f"❌ Error: {str(e)}"
138
-
139
- def get_first_frame(video_file):
140
- """Extract first frame"""
141
- if video_file is None:
142
- return None
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  try:
145
- cap = cv2.VideoCapture(video_file)
146
- ret, frame = cap.read()
147
- cap.release()
148
-
149
- if ret:
150
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
151
- return frame_rgb
152
- return None
153
  except Exception as e:
154
- print(f"Error extracting frame: {e}")
155
- return None
156
 
157
- # Create interface
158
- with gr.Blocks(title="VisioTrack - Object Tracker") as demo:
159
- gr.Markdown("""
160
- # 🎯 VisioTrack - SiamRPN Object Tracker
161
-
162
- Upload a video and specify a bounding box around the object you want to track!
163
- """)
164
-
165
- with gr.Row():
166
- with gr.Column(scale=1):
167
- gr.Markdown("### 📹 Step 1: Upload Video")
168
- video_input = gr.Video(label="Upload Video")
169
-
170
- gr.Markdown("### 🎯 Step 2: Define Bounding Box")
171
- gr.Markdown("Enter coordinates for the object (on first frame)")
172
-
173
- with gr.Row():
174
- bbox_x = gr.Number(label="X (left)", value=100, precision=0)
175
- bbox_y = gr.Number(label="Y (top)", value=100, precision=0)
176
-
177
- with gr.Row():
178
- bbox_w = gr.Number(label="Width", value=200, precision=0)
179
- bbox_h = gr.Number(label="Height", value=200, precision=0)
180
-
181
- gr.Markdown("""
182
- **Tips:**
183
- - X, Y: Top-left corner
184
- - Width, Height: Box size
185
- - Check first frame preview for coordinates
186
- """)
187
-
188
- track_btn = gr.Button("🚀 Start Tracking", variant="primary")
189
-
190
- device_info = gr.Textbox(
191
- label="Device",
192
- value="🚀 GPU" if torch.cuda.is_available() else "💻 CPU",
193
- interactive=False
194
- )
195
-
196
- with gr.Column(scale=1):
197
- gr.Markdown("### 🖼️ First Frame Preview")
198
- first_frame_display = gr.Image(label="First Frame")
199
-
200
- gr.Markdown("### 📥 Output")
201
- status_output = gr.Markdown("Upload a video to begin...")
202
- video_output = gr.Video(label="Tracked Video")
203
-
204
- # Events
205
- video_input.change(
206
- fn=get_first_frame,
207
- inputs=[video_input],
208
- outputs=[first_frame_display]
209
- )
210
-
211
- track_btn.click(
212
- fn=track_video,
213
- inputs=[video_input, bbox_x, bbox_y, bbox_w, bbox_h],
214
- outputs=[video_output, status_output]
215
- )
216
-
217
- gr.Markdown("""
218
- ---
219
- ### 💡 Example Usage
220
-
221
- 1. Upload video
222
- 2. View first frame
223
- 3. Enter bounding box (x, y, width, height)
224
- 4. Click "Start Tracking"
225
- 5. Download result
226
-
227
- ### 📐 Example Coordinates (1920x1080 video)
228
- - Person in center: X=800, Y=300, W=300, H=600
229
- - Car on left: X=200, Y=400, W=400, H=300
230
-
231
- ### 🎯 Best For
232
- Cars 🚗 | People 🚶 | Animals 🐱 | Sports ⚽
233
- """)
234
 
235
  if __name__ == "__main__":
236
- demo.launch()
 
 
1
#!/usr/bin/env python
"""
FastAPI Server for VisioTrack on Hugging Face Spaces
REST API for object tracking in videos
"""

import logging
import os
import shutil
import subprocess
import tempfile
from pathlib import Path

import cv2
import numpy as np
import torch
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse

from siamrpn import TrackerSiamRPN

# Module-level logger for the whole service.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application; interactive Swagger UI is served at the root path.
app = FastAPI(
    title="VisioTrack API",
    description="Object tracking API using SiamRPN",
    version="1.0.0",
    docs_url="/",  # Swagger UI at root
    redoc_url="/redoc"
)

# Allow any origin so browser frontends hosted elsewhere can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Path to the SiamRPN weights; the tracker itself is loaded lazily.
MODEL_PATH = "model.pth"
tracker = None
device = None
 
 
49
  """Load the SiamRPN tracker with GPU support"""
50
  global tracker, device
51
  if tracker is None:
52
+ if not os.path.exists(MODEL_PATH):
53
+ raise FileNotFoundError(f"Model file '{MODEL_PATH}' not found!")
54
+
55
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
56
+ tracker = TrackerSiamRPN(net_path=MODEL_PATH)
57
+ logger.info(f"✓ Tracker loaded on {device}")
 
58
  return tracker
59
 
60
+ def process_video_tracking(video_path: str, bbox_x: int, bbox_y: int,
61
+ bbox_w: int, bbox_h: int):
62
  """
63
+ Process video with object tracking
64
+
65
+ Args:
66
+ video_path: Path to input video
67
+ bbox_x, bbox_y, bbox_w, bbox_h: Bounding box coordinates
68
+
69
+ Returns:
70
+ tuple: (output_path, message, metadata)
71
  """
72
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  tracker_instance = load_tracker()
 
74
 
75
+ cap = cv2.VideoCapture(video_path)
 
76
  if not cap.isOpened():
77
+ return None, "Could not open video file", None
78
 
79
  # Get video properties
80
  fps = int(cap.get(cv2.CAP_PROP_FPS))
 
84
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
85
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
86
 
87
+ logger.info(f"Video: {width}x{height} @ {fps}fps, {total_frames} frames")
 
88
 
89
+ ret, frame = cap.read()
 
90
  if not ret:
91
+ return None, "Could not read first frame", None
 
92
 
93
  # Validate bounding box
94
+ if bbox_w <= 0 or bbox_h <= 0:
95
+ return None, "Invalid bounding box dimensions", None
96
+
97
+ if (bbox_x < 0 or bbox_y < 0 or
98
+ bbox_x + bbox_w > width or bbox_y + bbox_h > height):
99
+ return None, f"Bounding box out of bounds (frame: {width}x{height})", None
100
+
101
+ bbox = [bbox_x, bbox_y, bbox_w, bbox_h]
102
 
103
  # Initialize tracker
104
+ tracker_instance.init(frame, bbox)
 
105
 
106
+ # Create temporary output file
107
+ temp_output = tempfile.NamedTemporaryFile(delete=False, suffix='_temp.mp4')
108
  temp_output.close()
 
109
 
110
+ # Use XVID codec for initial write
111
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
112
+ writer = cv2.VideoWriter(temp_output.name, fourcc, fps, (width, height))
113
 
114
+ if not writer.isOpened():
115
+ return None, "Could not create video writer", None
 
116
 
117
+ # Draw first frame with initial bbox
118
+ x, y, w, h = [int(v) for v in bbox]
119
+ cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 3)
120
+ cv2.putText(frame, 'Frame: 1', (10, 30),
121
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
122
+ writer.write(frame)
123
 
124
+ # Process remaining frames
125
  frame_count = 1
126
+
127
  while True:
128
  ret, frame = cap.read()
129
  if not ret:
130
  break
131
 
132
+ frame_count += 1
133
+
134
+ # Update tracker
135
  bbox = tracker_instance.update(frame)
 
136
 
137
+ # Draw tracking result
138
+ x, y, w, h = [int(v) for v in bbox]
139
  x = max(0, min(x, width - 1))
140
  y = max(0, min(y, height - 1))
141
  w = max(1, min(w, width - x))
142
  h = max(1, min(h, height - y))
143
 
144
+ cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 3)
145
+ cv2.putText(frame, f'Frame: {frame_count}', (10, 30),
146
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
147
 
148
+ writer.write(frame)
149
+
150
+ if frame_count % 30 == 0:
151
+ logger.info(f"Processed {frame_count}/{total_frames} frames")
152
 
153
  cap.release()
154
+ writer.release()
155
+
156
+ # Re-encode with H.264 for browser compatibility
157
+ final_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
158
+ final_output.close()
159
+
160
+ try:
161
+ logger.info("Re-encoding video for browser compatibility...")
162
+ subprocess.run([
163
+ 'ffmpeg', '-i', temp_output.name,
164
+ '-c:v', 'libx264',
165
+ '-preset', 'fast',
166
+ '-crf', '23',
167
+ '-pix_fmt', 'yuv420p',
168
+ '-movflags', '+faststart',
169
+ '-y',
170
+ final_output.name
171
+ ], check=True, capture_output=True, text=True)
172
+
173
+ os.unlink(temp_output.name)
174
+ logger.info("✓ Video re-encoded successfully")
175
+
176
+ except (subprocess.CalledProcessError, FileNotFoundError) as e:
177
+ logger.warning(f"FFmpeg encoding failed: {e}, using original")
178
+ shutil.move(temp_output.name, final_output.name)
179
+
180
+ metadata = {
181
+ 'frames_processed': frame_count,
182
+ 'resolution': f"{width}x{height}",
183
+ 'fps': fps,
184
+ 'device': str(device)
185
+ }
186
 
187
+ return final_output.name, f"Successfully tracked {frame_count} frames", metadata
188
+
189
+ except Exception as e:
190
+ logger.error(f"Tracking error: {str(e)}")
191
+ return None, f"Error: {str(e)}", None
192
+
193
+
194
@app.get("/health")
async def health_check():
    """
    Health check endpoint (required by HF Spaces)
    """
    cuda_ok = torch.cuda.is_available()
    payload = {
        'status': 'healthy',
        'gpu_available': cuda_ok,
        'gpu_name': torch.cuda.get_device_name(0) if cuda_ok else None,
        'model_loaded': tracker is not None
    }
    return JSONResponse(payload)
205
 
 
 
 
 
206
 
207
@app.post("/track")
async def track_video(
    video: UploadFile = File(..., description="Video file to process"),
    bbox_x: int = Form(..., description="X coordinate of bounding box"),
    bbox_y: int = Form(..., description="Y coordinate of bounding box"),
    bbox_w: int = Form(..., description="Width of bounding box"),
    bbox_h: int = Form(..., description="Height of bounding box")
):
    """
    Main tracking endpoint

    Upload a video and bounding box coordinates to track an object.
    Returns the processed video with tracking visualization.

    Raises:
        HTTPException: 400 for invalid input or processing failure,
        500 for unexpected errors.
    """
    temp_input = None
    output_path = None

    try:
        # Validate file type. content_type can be None when the client omits
        # it; guard before startswith (bug fix: avoid AttributeError → 500).
        if not video.content_type or not video.content_type.startswith('video/'):
            raise HTTPException(status_code=400, detail="File must be a video")

        # Save uploaded video
        temp_input = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
        content = await video.read()
        temp_input.write(content)
        temp_input.close()

        logger.info(f"Processing video: {video.filename}")
        logger.info(f"Bounding box: ({bbox_x}, {bbox_y}, {bbox_w}, {bbox_h})")

        # Process video
        output_path, message, metadata = process_video_tracking(
            temp_input.name, bbox_x, bbox_y, bbox_w, bbox_h
        )

        if output_path is None:
            raise HTTPException(status_code=400, detail=message)

        # Return processed video.
        # NOTE(review): output_path is never deleted after being served; a
        # starlette BackgroundTask on FileResponse would clean it up — confirm.
        return FileResponse(
            output_path,
            media_type='video/mp4',
            filename='tracked_video.mp4',
            headers={
                'X-Frames-Processed': str(metadata['frames_processed']),
                'X-Resolution': metadata['resolution'],
                'X-FPS': str(metadata['fps'])
            }
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

    finally:
        # Cleanup temporary input file (bug fix: narrowed from bare `except:`,
        # which also swallowed KeyboardInterrupt/SystemExit).
        if temp_input and os.path.exists(temp_input.name):
            try:
                os.unlink(temp_input.name)
            except OSError:
                pass
271
+
272
+
273
@app.get("/info")
async def get_info():
    """
    Get API information and usage instructions
    """
    endpoint_map = {
        '/health': 'Health check',
        '/track': 'Track object in video (POST with multipart/form-data)',
        '/info': 'API information',
        '/': 'Interactive API documentation (Swagger UI)'
    }
    usage_spec = {
        'method': 'POST',
        'endpoint': '/track',
        'content_type': 'multipart/form-data',
        'parameters': {
            'video': 'Video file',
            'bbox_x': 'X coordinate (int)',
            'bbox_y': 'Y coordinate (int)',
            'bbox_w': 'Width (int)',
            'bbox_h': 'Height (int)'
        }
    }
    curl_example = '''
curl -X POST "https://your-space.hf.space/track" \\
  -F "video=@video.mp4" \\
  -F "bbox_x=100" \\
  -F "bbox_y=100" \\
  -F "bbox_w=200" \\
  -F "bbox_h=200" \\
  -o tracked_video.mp4
'''
    # Assemble the response from the named sections above.
    return {
        'name': 'VisioTrack API',
        'version': '1.0.0',
        'description': 'Object tracking API using SiamRPN',
        'endpoints': endpoint_map,
        'usage': usage_spec,
        'example_curl': curl_example
    }
310
+
311
+
312
@app.on_event("startup")
async def startup_event():
    """Load the tracking model when the server starts.

    The server keeps running even if loading fails, so /health can still
    report the (unloaded) model state instead of the Space crash-looping.
    """
    logger.info("=" * 50)
    logger.info("VisioTrack FastAPI Server Starting...")
    logger.info("=" * 50)
    try:
        load_tracker()
        logger.info("✓ Model loaded successfully")
    except Exception as e:
        # Bug fix: logger.exception captures the traceback (plain error()
        # lost it), and the stray leading space in the message is removed.
        logger.exception(f"Failed to load model: {e}")
        logger.info("=" * 50)
324
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
 
326
if __name__ == "__main__":
    # Local/dev entry point: serve on the Hugging Face Spaces default port.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)