jukrapopk commited on
Commit
4ae095d
·
1 Parent(s): e2e8d42

update requirements and update code to use torch cuda

Browse files
Files changed (2) hide show
  1. requirements.txt +4 -1
  2. src/app.py +13 -11
requirements.txt CHANGED
@@ -5,4 +5,7 @@ flask-cors==5.0.1
5
  PyThreadKiller==3.0.6
6
  vidgear==0.3.3
7
  selenium==4.32.0
8
- webdriver_manager==4.0.2
 
 
 
 
5
  PyThreadKiller==3.0.6
6
  vidgear==0.3.3
7
  selenium==4.32.0
8
+ webdriver_manager==4.0.2
9
+ torch==2.7.0+cu126
10
+ torchaudio==2.7.0+cu126
11
+ torchvision==0.22.0+cu126
src/app.py CHANGED
@@ -9,12 +9,8 @@ from utils import start_html_stream, get_mjpeg_frames, get_webpage_frames
9
  import cv2
10
  from ultralytics import solutions
11
  import time
 
12
 
13
- yolo_detection_model = YOLO('models/yolo11n.pt')
14
- yolo_segmentation_model = YOLO('models/yolo11n-seg.pt')
15
- yolo_pose_model = YOLO('models/yolo11n-pose.pt')
16
- yolo_orientation_model = YOLO('models/yolo11n-obb.pt')
17
- yolo_classification_model = YOLO('models/yolo11n-cls.pt')
18
  logging.getLogger('ultralytics').setLevel(logging.WARNING)
19
 
20
  PORT = 7860
@@ -96,14 +92,17 @@ def html_stream():
96
  session_id = f"{stream_url}_{model_type}_{should_track}_{''.join(map(str, classes))}"
97
 
98
  model = {
99
- 'detection': yolo_detection_model,
100
- 'segmentation': yolo_segmentation_model,
101
- 'pose': yolo_pose_model,
102
- 'orientation': yolo_orientation_model,
103
- 'classification': yolo_classification_model,
104
- 'regions': 'regions' if regions else yolo_detection_model,
105
  'preview': None
106
  }.get(model_type, None)
 
 
 
107
 
108
  if session_id not in sessions:
109
  sessions[session_id] = {
@@ -164,4 +163,7 @@ def data_stream():
164
  return f"Session {session_id} not found", 204
165
  return Response(generate(), mimetype='text/event-stream')
166
 
 
 
 
167
  app.run(host="0.0.0.0", port=PORT, debug=True, use_reloader=False)
 
9
  import cv2
10
  from ultralytics import solutions
11
  import time
12
+ import torch
13
 
 
 
 
 
 
14
  logging.getLogger('ultralytics').setLevel(logging.WARNING)
15
 
16
  PORT = 7860
 
92
  session_id = f"{stream_url}_{model_type}_{should_track}_{''.join(map(str, classes))}"
93
 
94
  model = {
95
+ 'detection': YOLO('models/yolo11n.pt'),
96
+ 'segmentation': YOLO('models/yolo11n-seg.pt'),
97
+ 'pose': YOLO('models/yolo11n-pose.pt'),
98
+ 'orientation': YOLO('models/yolo11n-obb.pt'),
99
+ 'classification': YOLO('models/yolo11n-cls.pt'),
100
+ 'regions': 'regions' if regions else YOLO('models/yolo11n.pt'),
101
  'preview': None
102
  }.get(model_type, None)
103
+
104
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
105
+ model.to(device)
106
 
107
  if session_id not in sessions:
108
  sessions[session_id] = {
 
163
  return f"Session {session_id} not found", 204
164
  return Response(generate(), mimetype='text/event-stream')
165
 
166
+ if __name__ == '__main__':
167
+ print("CUDA Available:", torch.cuda.is_available())
168
+
169
  app.run(host="0.0.0.0", port=PORT, debug=True, use_reloader=False)