Spaces:
Build error
Build error
| import json | |
| import logging | |
| from PyThreadKiller import PyThreadKiller | |
| from ultralytics import YOLO | |
| from flask import Flask, request, Response | |
| from flask_cors import CORS | |
| import threading | |
| from utils import start_html_stream, get_mjpeg_frames, get_webpage_frames | |
| import cv2 | |
| from ultralytics import solutions | |
| import time | |
| import torch | |
# Only WARNING-and-above records from the 'ultralytics' logger are emitted.
logging.getLogger('ultralytics').setLevel(logging.WARNING)

# Port the Flask development server binds to in the __main__ guard.
PORT = 7860

app = Flask(__name__)
# Accept cross-origin requests from any origin on every route.
CORS(app, resources={r"/*": {"origins": "*"}})

# Active processing sessions keyed by session id. Each value is a dict holding
# 'frame' (shared frame list), 'lock', 'threads', 'data' (shared result list)
# and, once started, 'thread' (the worker running process_frames).
sessions: dict = {}
def _boxes_to_objects(boxes, names):
    """Convert an ultralytics box collection into JSON-serializable dicts.

    Each entry holds 'box' (``box.xyxy.tolist()``), 'label' (``names`` looked
    up by the integer class id) and 'confidence' (``box.conf.item()``).
    ``None`` (e.g. classification results, which carry no boxes) yields ``[]``.
    """
    if boxes is None:
        return []
    return [
        {
            'box': box.xyxy.tolist(),
            'label': names[int(box.cls)],
            'confidence': box.conf.item(),
        }
        for box in boxes
    ]


def process_frames(stream_url, model, should_track, is_locked, frames_per_second, classes, regions, output_frames, output_data, max_buffered=30):
    """Run inference on a video source and publish annotated frames and detections.

    Loops until the frame generator for ``stream_url`` is exhausted.

    Args:
        stream_url: Source URL. If OpenCV can open it, frames come from
            ``get_mjpeg_frames``; otherwise from ``get_webpage_frames``.
        model: A YOLO model, the string ``'regions'`` (a
            ``solutions.RegionCounter`` is built here), or ``None`` to publish
            raw frames.
        should_track: Use ``model.track`` with persisted tracker state instead of
            ``model.predict``; forwarded as ``persist`` to the RegionCounter.
        is_locked: Lock guarding ``output_frames`` and ``output_data``.
        frames_per_second: Target processing rate. Non-positive values disable
            throttling.
        classes: Class ids to keep. When empty, no inference runs and raw frames
            are published.
        regions: Region definitions passed to ``RegionCounter`` (regions mode).
        output_frames: Shared list; the newest frame is inserted at index 0.
        output_data: Shared list; result dicts ``{'objects': [...]}`` are
            appended, newest last.
        max_buffered: Maximum entries kept in each shared list; older entries
            are dropped so memory stays bounded.

    Raises:
        ValueError: If ``max_buffered`` is less than 1.
    """
    if max_buffered < 1:
        raise ValueError('max_buffered must be at least 1')

    # Probe the URL once to pick the frame reader.
    cap = cv2.VideoCapture(stream_url)
    get_frames = get_mjpeg_frames if cap.isOpened() else get_webpage_frames
    cap.release()

    is_regions = model == 'regions'
    counter = None
    if is_regions:
        counter = solutions.RegionCounter(
            show=False,
            region=regions,
            model='models/yolo11n.pt',
            classes=classes,
            persist=should_track,
        )

    # Guard against ZeroDivisionError; a non-positive rate means "no throttling".
    frame_interval = 1 / frames_per_second if frames_per_second > 0 else 0
    run_inference = model is not None and len(classes) > 0

    for frame in get_frames(stream_url):
        start_time = time.monotonic()

        # Inference runs outside the lock so consumers of the shared lists are
        # not blocked for the duration of a model call.
        data = None
        if not run_inference:
            out_frame = frame
        elif is_regions:
            out_frame = counter.count(frame)
            # TODO: expose per-region counts; counter.counting_regions[...]['counts']
            # was observed to always be 0 and needs investigation.
            data = {'objects': _boxes_to_objects(counter.track_data, counter.model.names)}
        else:
            if should_track:
                # persist=True keeps tracker state between calls so track IDs
                # carry over from frame to frame.
                results = model.track(frame, classes=classes, persist=True)
            else:
                results = model.predict(frame, classes=classes)
            result = results[0]
            # OBB models report detections in `.obb` and leave `.boxes` as None.
            boxes = result.boxes if result.boxes is not None else result.obb
            data = {'objects': _boxes_to_objects(boxes, model.names)}
            out_frame = result.plot()

        with is_locked:
            if data is not None:
                output_data.append(data)
                del output_data[:-max_buffered]  # drop oldest results beyond the cap
            output_frames.insert(0, out_frame)
            del output_frames[max_buffered:]  # drop oldest frames beyond the cap

        elapsed_time = time.monotonic() - start_time
        time.sleep(max(0, frame_interval - elapsed_time))
# Weights file per model type. 'preview' runs no model and is handled
# separately; 'regions' uses these weights only when no regions are given.
_MODEL_WEIGHTS = {
    'detection': 'models/yolo11n.pt',
    'segmentation': 'models/yolo11n-seg.pt',
    'pose': 'models/yolo11n-pose.pt',
    'orientation': 'models/yolo11n-obb.pt',
    'classification': 'models/yolo11n-cls.pt',
    'regions': 'models/yolo11n.pt',
}


def _load_model(model_type, regions):
    """Build the ``model`` argument ``process_frames`` expects for ``model_type``.

    Returns:
        ``'regions'`` when ``model_type`` is ``'regions'`` and regions were
        given; ``None`` for types without weights (``'preview'``); otherwise a
        single YOLO model moved to CUDA when available, else CPU.
    """
    if model_type == 'regions' and regions:
        # process_frames builds its own RegionCounter. NOTE(review): that
        # counter's device is not set here; confirm it picks the GPU.
        return 'regions'
    weights = _MODEL_WEIGHTS.get(model_type)
    if weights is None:
        return None
    model = YOLO(weights)
    model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    return model


def html_stream():
    """Start (or reuse) a processing session and stream its annotated frames.

    Query parameters:
        url: Stream source; ``http://`` is prepended when no http(s) scheme is given.
        model: A key of ``_MODEL_WEIGHTS`` or ``'preview'`` (default ``'detection'``).
        tracking: ``'true'`` (case-insensitive) enables tracking.
        classes: Comma-separated integer class ids (default: none).
        regions: Regions separated by ``_``; each region is ``~``-separated
            ``x,y`` integer points. Regions are keyed ``'1'``, ``'2'``, ... in order.

    Returns:
        The result of ``start_html_stream`` for the session's frame list and
        lock, or a 400 response for malformed parameters or an unknown model type.
    """
    stream_url = request.args.get('url', '')
    # NOTE(security): this client-supplied URL is fetched server-side without
    # any allow-list (potential SSRF).
    stream_url = stream_url if stream_url.startswith(('http://', 'https://')) else 'http://' + stream_url
    model_type = request.args.get('model', 'detection')
    should_track = request.args.get('tracking', 'false').lower() == 'true'
    raw_classes = request.args.get('classes', '')
    raw_regions = request.args.get('regions', '')
    try:
        classes = list(map(int, raw_classes.split(','))) if raw_classes else []
        regions = {
            str(i + 1): [tuple(map(int, point.split(','))) for point in part.split('~')]
            for i, part in enumerate(raw_regions.split('_'))
        } if raw_regions else None
    except ValueError:
        return "Invalid 'classes' or 'regions' parameter", 400
    if model_type != 'preview' and model_type not in _MODEL_WEIGHTS:
        return "Unknown model type", 400

    # NOTE(review): the id ignores `regions`, and joining class ids without a
    # separator lets e.g. [1, 23] and [12, 3] collide. kill_session and
    # data_stream build the same id, so change all three together.
    session_id = f"{stream_url}_{model_type}_{should_track}_{''.join(map(str, classes))}"
    if session_id not in sessions:
        # Load a model only when a new session is actually created.
        model = _load_model(model_type, regions)
        session = {
            'frame': [None],
            'lock': threading.Lock(),
            'threads': [],  # not populated in this code; the worker is stored under 'thread'
            'data': [],
        }
        sessions[session_id] = session
        process_thread = PyThreadKiller(
            target=process_frames,
            # 10 is the target processing rate in frames per second.
            args=(stream_url, model, should_track, session['lock'], 10, classes, regions, session['frame'], session['data']),
        )
        process_thread.daemon = True
        session['thread'] = process_thread
        process_thread.start()
    return start_html_stream(sessions[session_id]['frame'], sessions[session_id]['lock'])
def kill_session():
    """Kill a session's processing thread and remove the session.

    The session id is built from the ``url``, ``model``, ``tracking`` and
    ``classes`` query parameters in the same format ``html_stream`` uses.

    Returns:
        ``"Session <id> killed"`` with 200, ``"Session <id> not found"`` with
        204, or a 400 for a malformed ``classes`` parameter. Bodies are sent as
        ``text/plain`` because the id echoes the client-supplied URL; the
        default text/html would allow reflected XSS.
    """
    text_plain = {'Content-Type': 'text/plain; charset=utf-8'}
    stream_url = request.args.get('url', '')
    stream_url = stream_url if stream_url.startswith(('http://', 'https://')) else 'http://' + stream_url
    model_type = request.args.get('model', 'detection')
    should_track = request.args.get('tracking', 'false').lower() == 'true'
    raw_classes = request.args.get('classes', '')
    try:
        classes = list(map(int, raw_classes.split(','))) if raw_classes else []
    except ValueError:
        return "Invalid 'classes' parameter", 400, text_plain
    session_id = f"{stream_url}_{model_type}_{should_track}_{''.join(map(str, classes))}"

    session = sessions.get(session_id)
    if session is None:
        # NOTE(review): HTTP 204 responses carry no body, so this message never
        # reaches the client; 404 would be more accurate if callers allow it.
        return f"Session {session_id} not found", 204, text_plain
    with session['lock']:
        session['thread'].kill()
        del sessions[session_id]
    return f"Session {session_id} killed", 200, text_plain
def data_stream():
    """Stream a session's detection results as Server-Sent Events.

    The session id is built from the ``url``, ``model``, ``tracking`` and
    ``classes`` query parameters in the same format ``html_stream`` uses.
    About every 50 ms the newest pending result, if any, is popped from the
    session's ``'data'`` list and sent as a ``data:`` event. The stream ends
    once the session no longer exists (e.g. after ``kill_session``).

    Returns:
        A ``text/event-stream`` Response, a 400 for a malformed ``classes``
        parameter, or ``"Session <id> not found"`` with status 204.
    """
    text_plain = {'Content-Type': 'text/plain; charset=utf-8'}
    stream_url = request.args.get('url', '')
    stream_url = stream_url if stream_url.startswith(('http://', 'https://')) else 'http://' + stream_url
    model_type = request.args.get('model', 'detection')
    should_track = request.args.get('tracking', 'false').lower() == 'true'
    raw_classes = request.args.get('classes', '')
    try:
        classes = list(map(int, raw_classes.split(','))) if raw_classes else []
    except ValueError:
        return "Invalid 'classes' parameter", 400, text_plain
    session_id = f"{stream_url}_{model_type}_{should_track}_{''.join(map(str, classes))}"

    if session_id not in sessions:
        # NOTE(review): HTTP 204 responses carry no body, so only the status
        # reaches the client.
        return f"Session {session_id} not found", 204, text_plain

    def generate():
        while True:
            # Re-fetch each time: stop cleanly once the session is removed
            # instead of raising KeyError mid-stream.
            session = sessions.get(session_id)
            if session is None:
                return
            with session['lock']:
                data = session['data'].pop() if session['data'] else None
            # Yield and sleep outside the lock so a slow client cannot stall
            # the processing thread that appends to 'data'.
            if data is not None:
                yield f"data: {json.dumps(data)}\n\n"
            time.sleep(0.05)

    return Response(generate(), mimetype='text/event-stream')
if __name__ == '__main__':
    print("CUDA Available:", torch.cuda.is_available())
    # Debug mode is opt-in: Flask's app.run reads FLASK_DEBUG when `debug` is
    # not passed. The Werkzeug debugger's interactive console must not be
    # reachable on a server bound to 0.0.0.0 by default.
    app.run(host="0.0.0.0", port=PORT, use_reloader=False)