Spaces:
Build error
Build error
update requirements and update code to use torch cuda
Browse files- requirements.txt +4 -1
- src/app.py +13 -11
requirements.txt
CHANGED
|
@@ -5,4 +5,7 @@ flask-cors==5.0.1
|
|
| 5 |
PyThreadKiller==3.0.6
|
| 6 |
vidgear==0.3.3
|
| 7 |
selenium==4.32.0
|
| 8 |
-
webdriver_manager==4.0.2
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
PyThreadKiller==3.0.6
|
| 6 |
vidgear==0.3.3
|
| 7 |
selenium==4.32.0
|
| 8 |
+
webdriver_manager==4.0.2
|
| 9 |
+
torch==2.7.0+cu126
|
| 10 |
+
torchaudio==2.7.0+cu126
|
| 11 |
+
torchvision==0.22.0+cu126
|
src/app.py
CHANGED
|
@@ -9,12 +9,8 @@ from utils import start_html_stream, get_mjpeg_frames, get_webpage_frames
|
|
| 9 |
import cv2
|
| 10 |
from ultralytics import solutions
|
| 11 |
import time
|
|
|
|
| 12 |
|
| 13 |
-
yolo_detection_model = YOLO('models/yolo11n.pt')
|
| 14 |
-
yolo_segmentation_model = YOLO('models/yolo11n-seg.pt')
|
| 15 |
-
yolo_pose_model = YOLO('models/yolo11n-pose.pt')
|
| 16 |
-
yolo_orientation_model = YOLO('models/yolo11n-obb.pt')
|
| 17 |
-
yolo_classification_model = YOLO('models/yolo11n-cls.pt')
|
| 18 |
logging.getLogger('ultralytics').setLevel(logging.WARNING)
|
| 19 |
|
| 20 |
PORT = 7860
|
|
@@ -96,14 +92,17 @@ def html_stream():
|
|
| 96 |
session_id = f"{stream_url}_{model_type}_{should_track}_{''.join(map(str, classes))}"
|
| 97 |
|
| 98 |
model = {
|
| 99 |
-
'detection':
|
| 100 |
-
'segmentation':
|
| 101 |
-
'pose':
|
| 102 |
-
'orientation':
|
| 103 |
-
'classification':
|
| 104 |
-
'regions': 'regions' if regions else
|
| 105 |
'preview': None
|
| 106 |
}.get(model_type, None)
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
if session_id not in sessions:
|
| 109 |
sessions[session_id] = {
|
|
@@ -164,4 +163,7 @@ def data_stream():
|
|
| 164 |
return f"Session {session_id} not found", 204
|
| 165 |
return Response(generate(), mimetype='text/event-stream')
|
| 166 |
|
|
|
|
|
|
|
|
|
|
| 167 |
app.run(host="0.0.0.0", port=PORT, debug=True, use_reloader=False)
|
|
|
|
| 9 |
import cv2
|
| 10 |
from ultralytics import solutions
|
| 11 |
import time
|
| 12 |
+
import torch
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
logging.getLogger('ultralytics').setLevel(logging.WARNING)
|
| 15 |
|
| 16 |
PORT = 7860
|
|
|
|
| 92 |
session_id = f"{stream_url}_{model_type}_{should_track}_{''.join(map(str, classes))}"
|
| 93 |
|
| 94 |
model = {
|
| 95 |
+
'detection': YOLO('models/yolo11n.pt'),
|
| 96 |
+
'segmentation': YOLO('models/yolo11n-seg.pt'),
|
| 97 |
+
'pose': YOLO('models/yolo11n-pose.pt'),
|
| 98 |
+
'orientation': YOLO('models/yolo11n-obb.pt'),
|
| 99 |
+
'classification': YOLO('models/yolo11n-cls.pt'),
|
| 100 |
+
'regions': 'regions' if regions else YOLO('models/yolo11n.pt'),
|
| 101 |
'preview': None
|
| 102 |
}.get(model_type, None)
|
| 103 |
+
|
| 104 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 105 |
+
model.to(device)
|
| 106 |
|
| 107 |
if session_id not in sessions:
|
| 108 |
sessions[session_id] = {
|
|
|
|
| 163 |
return f"Session {session_id} not found", 204
|
| 164 |
return Response(generate(), mimetype='text/event-stream')
|
| 165 |
|
| 166 |
+
if __name__ == '__main__':
|
| 167 |
+
print("CUDA Available:", torch.cuda.is_available())
|
| 168 |
+
|
| 169 |
app.run(host="0.0.0.0", port=PORT, debug=True, use_reloader=False)
|