# jukrapopk's picture
# update requirements and update code to use torch cuda
# 4ae095d
import json
import logging
from PyThreadKiller import PyThreadKiller
from ultralytics import YOLO
from flask import Flask, request, Response
from flask_cors import CORS
import threading
from utils import start_html_stream, get_mjpeg_frames, get_webpage_frames
import cv2
from ultralytics import solutions
import time
import torch
# Only show warnings and errors from the ultralytics logger.
logging.getLogger('ultralytics').setLevel(logging.WARNING)
# TCP port the Flask server listens on when this file is run as a script.
PORT = 7860
app = Flask(__name__)
# Allow cross-origin requests from any origin on every route.
CORS(app, resources={r"/*": {"origins": "*"}})
# Active sessions keyed by session id (built from url, model, tracking flag and
# classes). Each value is a dict holding the shared 'frame' list, the 'lock'
# guarding it, the worker 'thread' and the pending detection 'data' list.
sessions = {}
# Upper bound on buffered detection payloads, so the list cannot grow forever
# when no /data client is consuming it.
_MAX_PENDING_RESULTS = 256


def _serialize_detections(detections, names):
    """Turn an iterable of Ultralytics box objects into JSON-friendly dicts."""
    return [
        {
            'box': box.xyxy.tolist(),
            'label': names[int(box.cls)],
            'confidence': box.conf.item(),
        }
        for box in detections
    ]


def process_frames(stream_url, model, should_track, is_locked, frames_per_second, classes, regions, output_frames, output_data):
    """Read frames from a stream, run the selected model and publish the results.

    Intended as the target of a background thread: it loops for as long as the
    frame source keeps yielding frames.

    Args:
        stream_url: URL of the stream. If OpenCV can open it, frames come from
            ``get_mjpeg_frames``; otherwise from ``get_webpage_frames``.
        model: A loaded YOLO model, the string ``'regions'`` (use a
            ``solutions.RegionCounter``), or ``None`` to publish raw frames.
        should_track: Use ``model.track`` instead of ``model.predict``; also
            forwarded to the region counter as ``persist``.
        is_locked: Lock (context manager) guarding ``output_frames`` and
            ``output_data``.
        frames_per_second: Target processing rate. A non-positive value
            disables throttling instead of raising ``ZeroDivisionError``.
        classes: Class ids to detect. When empty, no inference is run and raw
            frames are published.
        regions: Region definition handed to ``solutions.RegionCounter``.
        output_frames: Shared list, mutated in place; index 0 holds the newest
            frame and older frames are discarded.
        output_data: Shared list, mutated in place; the newest detection
            payload is appended at the end and the backlog is capped.
    """
    # Probe the URL once with OpenCV only to choose the frame source.
    cap = cv2.VideoCapture(stream_url)
    get_frames = get_mjpeg_frames if cap.isOpened() else get_webpage_frames
    cap.release()

    counter = None
    if model == 'regions':
        counter = solutions.RegionCounter(
            show=False,
            region=regions,
            model='models/yolo11n.pt',
            classes=classes,
            persist=should_track,
        )

    frame_interval = 1 / frames_per_second if frames_per_second > 0 else 0
    run_inference = model is not None and len(classes) > 0

    for frame in get_frames(stream_url):
        start_time = time.time()
        with is_locked:
            if not run_inference:
                published_frame = frame
            elif counter is not None:
                published_frame = counter.count(frame)
                output_data.append({
                    'objects': _serialize_detections(counter.track_data, counter.model.names),
                    # TODO: expose per-region counts; reading
                    # counter.counting_regions always reported 0 when tried.
                })
            else:
                results = (
                    model.track(frame, classes=classes)
                    if should_track
                    else model.predict(frame, classes=classes)
                )
                result = results[0]
                # `boxes` is None for tasks that produce no axis-aligned boxes
                # (e.g. classification, OBB); fall back to `obb`, then to "no
                # objects", rather than crashing the worker thread.
                detections = result.boxes
                if detections is None:
                    detections = getattr(result, 'obb', None)
                if detections is None:
                    detections = ()
                output_data.append({
                    'objects': _serialize_detections(detections, model.names),
                })
                published_frame = result.plot()

            # Newest frame goes to index 0; drop the rest so the list does not
            # retain every frame ever processed.
            output_frames.insert(0, published_frame)
            del output_frames[1:]
            if len(output_data) > _MAX_PENDING_RESULTS:
                del output_data[:-_MAX_PENDING_RESULTS]

        elapsed_time = time.time() - start_time
        time.sleep(max(0, frame_interval - elapsed_time))
@app.route('/')
def html_stream():
    """Start (or reuse) a processing session and stream its frames.

    Query parameters:
        url: Stream URL; ``http://`` is prepended when no scheme is given.
        model: One of ``detection``, ``segmentation``, ``pose``,
            ``orientation``, ``classification``, ``regions`` or ``preview``.
            Unknown values behave like ``preview`` (no inference).
        tracking: ``true`` to enable tracking.
        classes: Comma-separated integer class ids.
        regions: Regions separated by ``_``; points within a region separated
            by ``~``; each point is ``x,y``.

    Returns:
        The response built by ``start_html_stream`` for the session, or a 400
        response when ``classes``/``regions`` cannot be parsed.
    """
    stream_url = request.args.get('url', '')
    if not stream_url.startswith(('http://', 'https://')):
        stream_url = 'http://' + stream_url
    model_type = request.args.get('model', 'detection')
    should_track = request.args.get('tracking', 'false').lower() == 'true'

    try:
        classes_arg = request.args.get('classes', '')
        classes = [int(class_id) for class_id in classes_arg.split(',')] if classes_arg else []
        regions_arg = request.args.get('regions', '')
        regions = {
            str(index + 1): [tuple(map(int, point.split(','))) for point in part.split('~')]
            for index, part in enumerate(regions_arg.split('_'))
        } if regions_arg else None
    except ValueError:
        return "Invalid 'classes' or 'regions' parameter", 400

    # NOTE: the id format is shared with /kill and /data, so it must not change here.
    session_id = f"{stream_url}_{model_type}_{should_track}_{''.join(map(str, classes))}"

    if session_id not in sessions:
        # Load only the requested weights, and only when a new session is
        # created, instead of instantiating every model on every request.
        weights_by_type = {
            'detection': 'models/yolo11n.pt',
            'segmentation': 'models/yolo11n-seg.pt',
            'pose': 'models/yolo11n-pose.pt',
            'orientation': 'models/yolo11n-obb.pt',
            'classification': 'models/yolo11n-cls.pt',
            'regions': 'models/yolo11n.pt',
        }
        if model_type == 'regions' and regions:
            # process_frames builds its own RegionCounter for this marker.
            model = 'regions'
        elif model_type in weights_by_type:
            model = YOLO(weights_by_type[model_type])
            # Only real models can be moved to a device; 'regions' and preview
            # (None) have no `.to`.
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            model.to(device)
        else:
            model = None

        sessions[session_id] = {
            'frame': [None],
            'lock': threading.Lock(),
            'threads': [],
            'data': [],
        }
        session = sessions[session_id]
        process_thread = PyThreadKiller(
            target=process_frames,
            args=(stream_url, model, should_track, session['lock'], 10, classes, regions, session['frame'], session['data']),
        )
        process_thread.daemon = True
        session['thread'] = process_thread
        process_thread.start()

    return start_html_stream(sessions[session_id]['frame'], sessions[session_id]['lock'])
@app.route('/kill')
def kill_session():
    """Stop the worker thread of a session and forget the session.

    The session is identified by the same ``url``, ``model``, ``tracking`` and
    ``classes`` query parameters that were used to start it.
    """
    query = request.args

    raw_url = query.get('url', '')
    has_scheme = raw_url.startswith(('http://', 'https://'))
    stream_url = raw_url if has_scheme else 'http://' + raw_url

    model_type = query.get('model', 'detection')
    should_track = query.get('tracking', 'false').lower() == 'true'

    raw_classes = query.get('classes', '')
    classes = [int(class_id) for class_id in raw_classes.split(',')] if raw_classes else []

    session_id = f"{stream_url}_{model_type}_{should_track}_{''.join(map(str, classes))}"

    if session_id not in sessions:
        # NOTE(review): a 204 response carries no body, so this message may
        # never reach the client - confirm whether 404 was intended.
        return f"Session {session_id} not found", 204

    # Kill the worker while holding the lock, i.e. while it is not inside its
    # locked section.
    with sessions[session_id]['lock']:
        sessions[session_id]['thread'].kill()
        del sessions[session_id]
    return f"Session {session_id} killed", 200
@app.route('/data')
def data_stream():
    """Stream a session's detection payloads as Server-Sent Events.

    The session is identified by the same ``url``, ``model``, ``tracking`` and
    ``classes`` query parameters that were used to start it. Roughly every
    50 ms the most recently appended payload, if any, is popped and sent as a
    ``data:`` event. The stream ends when the session is removed.

    Returns:
        A ``text/event-stream`` response, or a 204 response when the session
        does not exist.
    """
    stream_url = request.args.get('url', '')
    stream_url = stream_url if stream_url.startswith(('http://', 'https://')) else 'http://' + stream_url
    model_type = request.args.get('model', 'detection')
    should_track = request.args.get('tracking', 'false').lower() == 'true'
    classes = request.args.get('classes', '')
    classes = list(map(int, classes.split(','))) if classes else []
    session_id = f"{stream_url}_{model_type}_{should_track}_{''.join(map(str, classes))}"

    def generate():
        while True:
            # The session can be deleted by /kill while this stream is open;
            # end the stream cleanly instead of raising KeyError.
            session = sessions.get(session_id)
            if session is None:
                return

            payload = None
            with session['lock']:
                if session['data']:
                    payload = session['data'].pop()

            # Yield only after releasing the lock, so a slow client cannot
            # stall the frame-processing thread.
            if payload is not None:
                yield f"data: {json.dumps(payload)}\n\n"
            time.sleep(0.05)

    if session_id not in sessions:
        # NOTE(review): a 204 response carries no body, so this message may
        # never reach the client - confirm whether 404 was intended.
        return f"Session {session_id} not found", 204
    return Response(generate(), mimetype='text/event-stream')
if __name__ == '__main__':
    # Report GPU availability once at startup.
    cuda_available = torch.cuda.is_available()
    print("CUDA Available:", cuda_available)
    # NOTE(review): debug=True together with host 0.0.0.0 exposes the Werkzeug
    # debugger on every interface - confirm this is intended for deployment.
    app.run(
        host="0.0.0.0",
        port=PORT,
        debug=True,
        use_reloader=False,
    )