Spaces:
Build error
Build error
| """ | |
| Optical Flow Server for Hugging Face Spaces | |
| ============================================ | |
| Complete server for cloud-based optical flow processing. | |
| """ | |
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import base64 | |
| from io import BytesIO | |
| from PIL import Image | |
| from sklearn.cluster import DBSCAN | |
| import colorsys | |
| from collections import deque | |
| import uuid | |
| import time | |
| import json | |
class ServerOpticalFlowEngine:
    """Optical flow engine for server-side processing with persistent tracking.

    One instance holds the state of a single video stream:
    - the previous grayscale frame (for Farneback optical flow);
    - a MOG2 background model (for motion segmentation);
    - the tracked objects;
    - their centroid trails.

    Frames are meant to be fed in order through ``process_frame``.
    """

    # A detection is associated with an existing track only if their IoU is
    # strictly greater than this value.
    MATCH_IOU_THRESHOLD = 0.3

    def __init__(self):
        self.prev_gray = None
        # track id -> deque of recent (x, y) centroids, newest last.
        self.trail_points = {}
        self.max_trail_length = 30
        self.bg_subtractor = self._create_bg_subtractor()
        self.color_pool = self._generate_colors(20)
        # Persistent object tracking: track id -> latest state dict.
        self.tracked_objects = {}
        self.next_object_id = 1
        # A track is dropped once it has gone unmatched for more than this
        # many consecutive frames.
        self.max_frames_missing = 15

    @staticmethod
    def _create_bg_subtractor():
        """Return a fresh MOG2 background subtractor with shadow detection on."""
        return cv2.createBackgroundSubtractorMOG2(
            history=500, varThreshold=16, detectShadows=True
        )

    def _generate_colors(self, n):
        """Return ``n`` BGR color tuples evenly spaced in hue."""
        colors = []
        for i in range(n):
            r, g, b = colorsys.hsv_to_rgb(i / n, 0.9, 0.95)
            # OpenCV drawing functions take colors in BGR order.
            colors.append((int(b * 255), int(g * 255), int(r * 255)))
        return colors

    def compute_flow(self, gray):
        """Return dense Farneback flow from the previous frame to ``gray``.

        Returns None on the first frame. It also returns None whenever the
        frame size differs from the previous frame, because Farneback
        requires both inputs to have the same shape. In both cases ``gray``
        becomes the new reference frame.
        """
        if self.prev_gray is None or self.prev_gray.shape != gray.shape:
            self.prev_gray = gray.copy()
            return None
        flow = cv2.calcOpticalFlowFarneback(
            self.prev_gray, gray, None,
            pyr_scale=0.5, levels=5, winsize=13,
            iterations=10, poly_n=5, poly_sigma=1.1,
            flags=cv2.OPTFLOW_FARNEBACK_GAUSSIAN
        )
        self.prev_gray = gray.copy()
        return flow

    def segment_motion(self, frame):
        """Return a cleaned binary foreground mask for ``frame``.

        Updates the background model as a side effect.
        """
        fg_mask = self.bg_subtractor.apply(frame)
        # With detectShadows=True, MOG2 labels shadow pixels with 127
        # (its default shadow value); treat them as background.
        fg_mask[fg_mask == 127] = 0
        kernel_small = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        fg_mask = cv2.erode(fg_mask, kernel_small, iterations=2)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
        fg_mask = cv2.morphologyEx(fg_mask, cv2.MORPH_OPEN, kernel)
        kernel_large = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
        fg_mask = cv2.morphologyEx(fg_mask, cv2.MORPH_CLOSE, kernel_large)
        fg_mask = cv2.dilate(fg_mask, kernel_small, iterations=1)
        return fg_mask

    def compute_iou(self, bbox1, bbox2):
        """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

        Returns 0.0 for disjoint boxes or a zero-area union.
        """
        x1_1, y1_1, x2_1, y2_1 = bbox1
        x1_2, y1_2, x2_2, y2_2 = bbox2
        xi1, yi1 = max(x1_1, x1_2), max(y1_1, y1_2)
        xi2, yi2 = min(x2_1, x2_2), min(y2_1, y2_2)
        if xi2 <= xi1 or yi2 <= yi1:
            return 0.0
        inter = (xi2 - xi1) * (yi2 - yi1)
        area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
        area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
        union = area1 + area2 - inter
        return inter / union if union > 0 else 0.0

    @staticmethod
    def _track_state(det):
        """Build stored track state from a detection dict, resetting the miss counter."""
        return {
            'centroid': det['centroid'], 'bbox': det['bbox'],
            'velocity': det['velocity'], 'area': det['area'],
            'contour': det.get('contour'), 'missing': 0,
        }

    def _mark_missing(self, track_ids):
        """Count one more missed frame for each id; drop stale tracks and their trails."""
        for tid in track_ids:
            track = self.tracked_objects[tid]
            track['missing'] += 1
            if track['missing'] > self.max_frames_missing:
                del self.tracked_objects[tid]
                self.trail_points.pop(tid, None)

    def match_and_track(self, detected):
        """Associate this frame's detections with persistent tracks via IoU.

        Matching is greedy. Each existing track, in insertion order, takes
        the not-yet-claimed detection with the highest IoU above
        ``MATCH_IOU_THRESHOLD``. Unmatched tracks age via ``_mark_missing``.
        Unmatched detections start new tracks with fresh ids.

        Args:
            detected: mapping of arbitrary keys to detection dicts with
                'centroid', 'bbox', 'velocity', 'area' and optionally
                'contour'. Only the values are used.

        Returns:
            dict mapping track id -> track state for every track matched or
            created in this frame. Tracks that were missed are not included.
        """
        if not detected:
            self._mark_missing(list(self.tracked_objects))
            return {}
        det_list = list(detected.values())
        tracked_ids = list(self.tracked_objects)
        matched_t, matched_d, result = set(), set(), {}
        for tid in tracked_ids:
            best_iou, best_idx = self.MATCH_IOU_THRESHOLD, None
            for i, det in enumerate(det_list):
                if i in matched_d:
                    continue
                iou = self.compute_iou(self.tracked_objects[tid]['bbox'], det['bbox'])
                if iou > best_iou:
                    best_iou, best_idx = iou, i
            if best_idx is not None:
                matched_t.add(tid)
                matched_d.add(best_idx)
                self.tracked_objects[tid].update(self._track_state(det_list[best_idx]))
                result[tid] = self.tracked_objects[tid]
        self._mark_missing([tid for tid in tracked_ids if tid not in matched_t])
        for i, det in enumerate(det_list):
            if i in matched_d:
                continue
            nid = self.next_object_id
            self.next_object_id += 1
            self.tracked_objects[nid] = self._track_state(det)
            result[nid] = self.tracked_objects[nid]
        return result

    def detect_objects(self, mask, flow=None, min_area=2000):
        """Extract moving blobs from ``mask`` and associate them with tracks.

        A contour is discarded if any of these hold:
        - its area is below ``min_area``;
        - its height/width ratio is outside [0.2, 8.0];
        - its solidity (area / convex-hull area) is below 0.3.

        When ``flow`` is given, each blob's velocity is the mean flow vector
        over its filled contour. Otherwise the velocity is (0, 0).

        Returns:
            ``(label_image, tracked)``.
            - ``label_image`` is an int32 image. Each kept contour is filled
              with its per-frame detection index; all other pixels are -1.
              It is None when the mask has no contours.
            - ``tracked`` is the result of ``match_and_track`` (keyed by
              persistent track id).
        """
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return None, self.match_and_track({})
        label_image = np.full(mask.shape, -1, dtype=np.int32)
        objects = {}
        label_id = 0
        for contour in contours:
            area = cv2.contourArea(contour)
            if area < min_area:
                continue
            x, y, w, h = cv2.boundingRect(contour)
            aspect_ratio = float(h) / w if w > 0 else 0
            if aspect_ratio < 0.2 or aspect_ratio > 8.0:
                continue
            hull_area = cv2.contourArea(cv2.convexHull(contour))
            if hull_area > 0 and (area / hull_area) < 0.3:
                continue
            cv2.drawContours(label_image, [contour], -1, label_id, -1)
            M = cv2.moments(contour)
            # Centroid from image moments; fall back to the bbox centre when
            # the contour's zeroth moment is 0.
            cx = int(M['m10'] / M['m00']) if M['m00'] > 0 else x + w // 2
            cy = int(M['m01'] / M['m00']) if M['m00'] > 0 else y + h // 2
            avg_velocity = (0, 0)
            if flow is not None:
                contour_mask = np.zeros(mask.shape, dtype=np.uint8)
                cv2.drawContours(contour_mask, [contour], -1, 255, -1)
                inside = contour_mask > 0
                flow_x = flow[:, :, 0][inside]
                flow_y = flow[:, :, 1][inside]
                if len(flow_x) > 0:
                    avg_velocity = (float(np.mean(flow_x)), float(np.mean(flow_y)))
            objects[label_id] = {
                'centroid': (cx, cy), 'bbox': (x, y, x + w, y + h),
                'velocity': avg_velocity, 'area': area, 'contour': contour
            }
            label_id += 1
        # Apply persistent tracking
        tracked = self.match_and_track(objects)
        return label_image, tracked

    def update_trails(self, objects):
        """Append each object's current centroid to its bounded trail."""
        for obj_id, obj in objects.items():
            if obj_id not in self.trail_points:
                self.trail_points[obj_id] = deque(maxlen=self.max_trail_length)
            self.trail_points[obj_id].append(obj['centroid'])

    def draw_results(self, frame, objects, label_image):
        """Return a copy of ``frame`` annotated with each tracked object.

        Each object gets its bbox, centroid, velocity arrow (when the speed
        exceeds 1 px/frame), id label and outline.

        ``label_image`` is accepted for backward compatibility but is no
        longer used. Its pixel values are per-frame detection indices, while
        ``objects`` is keyed by persistent track ids, so looking contours up
        by track id outlined the wrong blob. Each object's own stored contour
        is drawn instead.
        """
        output = frame.copy()
        for obj_id, obj in objects.items():
            color = self.color_pool[obj_id % len(self.color_pool)]
            x1, y1, x2, y2 = obj['bbox']
            cv2.rectangle(output, (x1, y1), (x2, y2), color, 2)
            cx, cy = obj['centroid']
            cv2.circle(output, (cx, cy), 5, color, -1)
            vx, vy = obj['velocity']
            if np.hypot(vx, vy) > 1:
                # Arrow length is 5x the per-frame flow so motion is visible.
                cv2.arrowedLine(output, (cx, cy),
                                (int(cx + vx * 5), int(cy + vy * 5)),
                                (0, 255, 255), 3, tipLength=0.3)
            cv2.putText(output, f"Obj {obj_id}", (x1, y1 - 25),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
            contour = obj.get('contour')
            if contour is not None:
                cv2.drawContours(output, [contour], -1, color, 2)
        return output

    def draw_trails(self, frame):
        """Return a copy of ``frame`` with every trail of 2+ points drawn.

        Line thickness grows toward the newest point, and a filled dot marks
        the latest position.
        """
        output = frame.copy()
        for obj_id, trail in self.trail_points.items():
            n = len(trail)
            if n < 2:
                continue
            color = self.color_pool[obj_id % len(self.color_pool)]
            # Plain Python int tuples are what cv2 drawing calls expect.
            pts = [(int(px), int(py)) for px, py in trail]
            for i in range(1, n):
                cv2.line(output, pts[i - 1], pts[i], color, int(1 + (i / n) * 3))
            cv2.circle(output, pts[-1], 6, color, -1)
        return output

    def compute_heatmap(self, flow):
        """Return a COLORMAP_HOT image of flow magnitude, or None if ``flow`` is None.

        The magnitude is min-max normalized per frame, so colors are
        relative, not absolute speeds.
        """
        if flow is None:
            return None
        magnitude = np.sqrt(flow[..., 0]**2 + flow[..., 1]**2)
        normalized = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX)
        return cv2.applyColorMap(normalized.astype(np.uint8), cv2.COLORMAP_HOT)

    def process_frame(self, frame):
        """Run flow, segmentation, tracking and drawing on one BGR frame.

        Returns:
            dict with:
            - 'num_objects': int;
            - 'tracked', 'trails', 'heatmap': images.
            On the first frame (no flow yet), the images are the unannotated
            frame and a black heatmap.
        """
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        flow = self.compute_flow(gray)
        motion_mask = self.segment_motion(frame)
        results = {'num_objects': 0, 'tracked': frame.copy(),
                   'trails': frame.copy(), 'heatmap': np.zeros_like(frame)}
        if flow is not None:
            results['heatmap'] = self.compute_heatmap(flow)
            label_image, objects = self.detect_objects(motion_mask, flow)
            if objects:
                results['num_objects'] = len(objects)
                self.update_trails(objects)
                results['tracked'] = self.draw_results(frame, objects, label_image)
                results['trails'] = self.draw_trails(results['tracked'])
        return results

    def reset(self):
        """Clear all per-stream state so the next frame starts a fresh stream."""
        self.prev_gray = None
        self.trail_points = {}
        self.tracked_objects = {}
        self.next_object_id = 1
        # Also discard the learned background; otherwise frames after a reset
        # are segmented against the previous stream's background model.
        self.bg_subtractor = self._create_bg_subtractor()
# Session management.
# session_id -> {'engine': ServerOpticalFlowEngine, 'last_access': time.time() value}
sessions = {}
SESSION_TIMEOUT = 300  # seconds of inactivity before a session is evicted


def get_or_create_session(session_id):
    """Return the engine bound to ``session_id``, creating it on first use.

    Each call first evicts every session idle for longer than
    SESSION_TIMEOUT. It then refreshes the last-access time of the
    requested session, or registers a new one.
    """
    now = time.time()
    stale_ids = [
        sid for sid, entry in sessions.items()
        if now - entry['last_access'] > SESSION_TIMEOUT
    ]
    for sid in stale_ids:
        del sessions[sid]
    entry = sessions.get(session_id)
    if entry is None:
        entry = {'engine': ServerOpticalFlowEngine(), 'last_access': now}
        sessions[session_id] = entry
    else:
        entry['last_access'] = now
    return entry['engine']
def decode_frame(base64_data):
    """Decode a base64-encoded image into a BGR uint8 array for OpenCV.

    An optional ``data:<mime>;base64,`` prefix (data-URL form) is stripped
    before decoding. The image is converted to 3-channel RGB first, so
    grayscale, palette and other non-RGB PIL modes are handled instead of
    failing in ``cv2.cvtColor``.
    """
    if isinstance(base64_data, str) and base64_data.startswith('data:'):
        base64_data = base64_data.split(',', 1)[-1]
    img_data = base64.b64decode(base64_data)
    with Image.open(BytesIO(img_data)) as img:
        rgb = np.array(img.convert('RGB'))
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
def encode_frame(frame):
    """Encode an OpenCV (BGR) image as a base64 JPEG string (quality 85).

    Raises:
        ValueError: if ``cv2.imencode`` reports failure. Previously the flag
            was ignored and an empty string could be returned silently.
    """
    ok, buffer = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
    if not ok:
        raise ValueError("JPEG encoding failed")
    return base64.b64encode(buffer).decode('utf-8')
# API functions exposed via Gradio.


def process_frame_api(frame_base64: str, session_id: str) -> str:
    """Process one base64-encoded frame for ``session_id``; return a JSON string.

    On success the payload has 'success' (True), 'num_objects' and the
    base64 JPEG images 'tracked', 'trails' and 'heatmap'. Any exception is
    reported as ``{"success": false, "error": <message>}`` instead of being
    raised, so remote callers always receive parseable JSON.
    """
    try:
        engine = get_or_create_session(session_id)
        frame = decode_frame(frame_base64)
        results = engine.process_frame(frame)
        payload = {'success': True, 'num_objects': results['num_objects']}
        for key in ('tracked', 'trails', 'heatmap'):
            payload[key] = encode_frame(results[key])
        return json.dumps(payload)
    except Exception as e:
        return json.dumps({'success': False, 'error': str(e)})
def reset_session_api(session_id: str) -> str:
    """Reset the engine state of ``session_id`` if that session exists.

    Unknown or already-evicted ids are ignored. The response is always the
    JSON ``{"success": true}``.
    """
    entry = sessions.get(session_id)
    if entry is not None:
        entry['engine'].reset()
    return json.dumps({'success': True})
def create_new_session() -> str:
    """Register a new session under a random UUID4 string and return that id."""
    new_id = f"{uuid.uuid4()}"
    get_or_create_session(new_id)
    return new_id
def test_image(image):
    """Run the engine on a single uploaded image for a quick visual check.

    ``gr.Image(type="numpy")`` delivers RGB arrays, while the engine works
    in OpenCV's BGR order. The image is therefore converted on the way in
    and the annotated result converted back on the way out. Without this,
    the overlay colors show with red and blue swapped.

    The same frame is fed twice. The first call only primes the engine
    (no flow exists on the first frame); the second produces the output.
    """
    if image is None:
        return None
    engine = ServerOpticalFlowEngine()
    frame = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    engine.process_frame(frame)
    tracked = engine.process_frame(frame)['tracked']
    return cv2.cvtColor(tracked, cv2.COLOR_BGR2RGB)
# Gradio UI with exposed API endpoints
with gr.Blocks(title="Optical Flow Server") as demo:
    gr.Markdown("# 🎥 Optical Flow Processing Server")
    gr.Markdown("Use: `python optical_flow_advanced.py --server " +
                "https://tremick-visual-odometry.hf.space`")
    with gr.Tab("Test"):
        with gr.Row():
            input_img = gr.Image(label="Upload Image", type="numpy")
            output_img = gr.Image(label="Processed")
        input_img.change(test_image, inputs=input_img, outputs=output_img)
    with gr.Tab("API"):
        gr.Markdown("### API Endpoints")
        gr.Markdown("Use the functions below via Gradio Client:")
        # Expose API functions. Events without an explicit api_name get one
        # derived from the handler's function name.
        session_btn = gr.Button("Create Session")
        session_out = gr.Textbox(label="Session ID")
        session_btn.click(create_new_session, outputs=session_out)
        with gr.Row():
            frame_input = gr.Textbox(label="Frame (base64)", lines=3)
            sid_input = gr.Textbox(label="Session ID")
        process_btn = gr.Button("Process Frame")
        result_out = gr.Textbox(label="Result (JSON)", lines=5)
        process_btn.click(process_frame_api, inputs=[frame_input, sid_input], outputs=result_out)
        # reset_session_api was defined but never wired to any event, so
        # neither the UI nor API clients could reach it.
        reset_btn = gr.Button("Reset Session")
        reset_out = gr.Textbox(label="Reset Result (JSON)")
        reset_btn.click(reset_session_api, inputs=sid_input, outputs=reset_out)
# Also create simple API interface for programmatic access
api_interface = gr.Interface(
    fn=process_frame_api,
    inputs=[gr.Textbox(label="frame_base64"), gr.Textbox(label="session_id")],
    outputs=gr.Textbox(label="result"),
    api_name="process_frame"
)
demo = gr.TabbedInterface(
    [demo, api_interface],
    ["UI", "API"]
)
if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")