"""Flask server that runs ultralytics YOLO models over remote video streams.

``/`` starts (or reuses) a background worker for a stream and hands the
session's frame buffer to ``utils.start_html_stream``; ``/data`` serves the
per-frame detections as server-sent events; ``/kill`` stops a worker. All
three identify a session by the same ``url``/``model``/``tracking``/
``classes`` query parameters.
"""
import json
import logging
import os
import threading
import time

import cv2
import torch
from flask import Flask, Response, abort, request
from flask_cors import CORS
from PyThreadKiller import PyThreadKiller
from ultralytics import YOLO, solutions

from utils import get_mjpeg_frames, get_webpage_frames, start_html_stream

logging.getLogger('ultralytics').setLevel(logging.WARNING)

PORT = 7860

# Weights file for each supported ``model`` query value ('preview' and unknown
# values mean "no model": frames are passed through unannotated).
MODEL_PATHS = {
    'detection': 'models/yolo11n.pt',
    'segmentation': 'models/yolo11n-seg.pt',
    'pose': 'models/yolo11n-pose.pt',
    'orientation': 'models/yolo11n-obb.pt',
    'classification': 'models/yolo11n-cls.pt',
    'regions': 'models/yolo11n.pt',
}

# Maximum number of unconsumed detection records kept per session; without a
# cap the list grows forever when no client is reading /data.
MAX_BUFFERED_DATA = 100

app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "*"}})

# session_id -> {'frame': [...], 'lock': Lock, 'data': [...], 'thread': PyThreadKiller}
sessions = {}
# Serialises creation/removal of entries in ``sessions`` so two concurrent
# requests for the same stream cannot both start a worker.
sessions_lock = threading.Lock()


def _serialize_boxes(boxes, names):
    """Convert ultralytics box objects into JSON-serialisable dicts."""
    return [
        {
            'box': box.xyxy.tolist(),
            'label': names[int(box.cls)],
            'confidence': box.conf.item(),
        }
        for box in boxes
    ]


def _result_boxes(result):
    """Return the per-object boxes of an ultralytics ``Results``, or ``[]``."""
    # Oriented-box models fill ``obb`` and leave ``boxes`` as None;
    # classification models fill neither.
    obb = getattr(result, 'obb', None)
    if obb is not None:
        return obb
    if result.boxes is not None:
        return result.boxes
    return []


def process_frames(stream_url, model, should_track, is_locked, frames_per_second,
                   classes, regions, output_frames, output_data):
    """Pull frames from *stream_url*, run *model* on them and publish results.

    Runs until the frame source is exhausted (or the thread is killed).

    Args:
        stream_url: Source URL. If OpenCV can open it, frames come from
            ``utils.get_mjpeg_frames``; otherwise from
            ``utils.get_webpage_frames``.
        model: A loaded ``YOLO`` model, the string ``'regions'`` to use an
            ultralytics ``RegionCounter``, or ``None`` to pass frames
            through unannotated.
        should_track: Use ``model.track`` (persisting tracks across frames)
            instead of ``model.predict``; for region counting it is passed
            as ``persist``.
        is_locked: Lock guarding *output_frames* and *output_data*.
        frames_per_second: Upper bound on the processing rate; non-positive
            values disable throttling.
        classes: Class ids to keep. When empty, frames are passed through
            without inference.
        regions: Region polygons for ``RegionCounter`` (only used when
            *model* is ``'regions'``).
        output_frames: Shared list; the newest frame is inserted at index 0.
        output_data: Shared list; one ``{'objects': [...]}`` record is
            appended per inferred frame, keeping at most
            ``MAX_BUFFERED_DATA`` records.
    """
    # Probe once: OpenCV can open MJPEG/video URLs directly, anything else
    # goes through the webpage frame grabber.
    cap = cv2.VideoCapture(stream_url)
    try:
        get_frames = get_mjpeg_frames if cap.isOpened() else get_webpage_frames
    finally:
        cap.release()

    use_regions = isinstance(model, str) and model == 'regions'
    counter = None
    if use_regions:
        counter = solutions.RegionCounter(
            show=False,
            region=regions,
            model=MODEL_PATHS['regions'],
            classes=classes,
            persist=should_track,
        )

    frame_interval = 1 / frames_per_second if frames_per_second > 0 else 0.0

    for frame in get_frames(stream_url):
        start_time = time.monotonic()
        with is_locked:
            if model is None or not classes:
                output_frames.insert(0, frame)
            elif use_regions:
                region_tracked_frame = counter.count(frame)
                track_data = counter.track_data if counter.track_data is not None else []
                # TODO: expose per-region counts; reading
                # counter.counting_regions[...]['counts'] always returned 0.
                output_data.append({'objects': _serialize_boxes(track_data, counter.model.names)})
                output_frames.insert(0, region_tracked_frame)
            else:
                if should_track:
                    # persist=True keeps tracker state between per-frame calls;
                    # without it track IDs are reset on every frame.
                    result = model.track(frame, classes=classes, persist=True)[0]
                else:
                    result = model.predict(frame, classes=classes)[0]
                output_data.append({'objects': _serialize_boxes(_result_boxes(result), model.names)})
                output_frames.insert(0, result.plot())

            # /data pops from the end, so dropping from the front discards the
            # oldest unconsumed records first.
            if len(output_data) > MAX_BUFFERED_DATA:
                del output_data[:-MAX_BUFFERED_DATA]
            # NOTE(review): output_frames grows by one frame per iteration and is
            # never trimmed here; confirm utils.start_html_stream bounds it.

        elapsed_time = time.monotonic() - start_time
        time.sleep(max(0.0, frame_interval - elapsed_time))


def _normalize_url(raw_url):
    """Prefix ``http://`` onto *raw_url* unless it already names http(s)."""
    return raw_url if raw_url.startswith(('http://', 'https://')) else 'http://' + raw_url


def _parse_session_params():
    """Read the query parameters shared by all routes.

    Returns:
        ``(stream_url, model_type, should_track, classes, session_id)``.

    Aborts with 400 if ``classes`` is not a comma-separated list of ints.
    """
    stream_url = _normalize_url(request.args.get('url', ''))
    model_type = request.args.get('model', 'detection')
    should_track = request.args.get('tracking', 'false').lower() == 'true'
    raw_classes = request.args.get('classes', '')
    try:
        classes = [int(c) for c in raw_classes.split(',')] if raw_classes else []
    except ValueError:
        abort(400, description="'classes' must be a comma-separated list of integers")
    # Class ids are joined with ',' so e.g. [1, 23] and [12, 3] do not collide.
    # NOTE: regions are not part of the id, so a request that differs only in
    # regions reuses the existing worker.
    session_id = f"{stream_url}_{model_type}_{should_track}_{','.join(map(str, classes))}"
    return stream_url, model_type, should_track, classes, session_id


def _parse_regions(raw_regions):
    """Parse ``x,y~x,y~..._x,y~...`` into ``{'1': [(x, y), ...], ...}``.

    Returns ``None`` for an empty string; raises ``ValueError`` on bad input.
    """
    if not raw_regions:
        return None
    return {
        str(index): [tuple(map(int, point.split(','))) for point in part.split('~')]
        for index, part in enumerate(raw_regions.split('_'), start=1)
    }


def _load_model(model_type, has_regions):
    """Build the model object ``process_frames`` expects for *model_type*.

    Returns ``None`` for 'preview'/unknown types, the string ``'regions'``
    for region counting with regions supplied, otherwise a ``YOLO`` model
    moved to CUDA when available.
    """
    if model_type == 'regions' and has_regions:
        return 'regions'
    path = MODEL_PATHS.get(model_type)
    if path is None:
        return None
    model = YOLO(path)
    model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    return model


@app.route('/')
def html_stream():
    """Start (or reuse) the worker for the requested stream and stream its frames."""
    stream_url, model_type, should_track, classes, session_id = _parse_session_params()
    try:
        regions = _parse_regions(request.args.get('regions', ''))
    except ValueError:
        abort(400, description="'regions' must look like 'x,y~x,y~x,y_x,y~...'")

    with sessions_lock:
        session = sessions.get(session_id)
        if session is None:
            # Only load a model when a new worker is actually needed.
            model = _load_model(model_type, regions is not None)
            session = {'frame': [None], 'lock': threading.Lock(), 'data': []}
            process_thread = PyThreadKiller(
                target=process_frames,
                args=(stream_url, model, should_track, session['lock'], 10,
                      classes, regions, session['frame'], session['data']),
            )
            process_thread.daemon = True
            session['thread'] = process_thread
            sessions[session_id] = session
            process_thread.start()

    return start_html_stream(session['frame'], session['lock'])


@app.route('/kill')
def kill_session():
    """Stop the worker for the requested stream and forget its session."""
    *_, session_id = _parse_session_params()
    with sessions_lock:
        session = sessions.pop(session_id, None)
    if session is None:
        return f"Session {session_id} not found", 204
    # Hold the session lock so the worker is not killed mid-update of the
    # shared buffers.
    with session['lock']:
        session['thread'].kill()
    return f"Session {session_id} killed", 200


@app.route('/data')
def data_stream():
    """Stream the session's detection records as server-sent events."""
    *_, session_id = _parse_session_params()
    if session_id not in sessions:
        return f"Session {session_id} not found", 204

    def generate():
        while True:
            session = sessions.get(session_id)
            if session is None:
                # Session was killed: end the event stream cleanly.
                return
            with session['lock']:
                data = session['data'].pop() if session['data'] else None
            # Yield outside the lock so a slow client never blocks the worker.
            if data is not None:
                yield f"data: {json.dumps(data)}\n\n"
            time.sleep(0.05)

    return Response(generate(), mimetype='text/event-stream')


if __name__ == '__main__':
    print("CUDA Available:", torch.cuda.is_available())
    # The Werkzeug interactive debugger permits arbitrary code execution and
    # the server listens on all interfaces, so debug mode is opt-in.
    debug = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true')
    app.run(host="0.0.0.0", port=PORT, debug=debug, use_reloader=False)