VietCat commited on
Commit
44fc75a
·
1 Parent(s): 95a780a

remove test existing model

Browse files
Files changed (2) hide show
  1. config.yaml +2 -2
  2. model.py +10 -16
config.yaml CHANGED
@@ -1,6 +1,6 @@
1
  model:
2
- path: 'nezahatkorkmaz/traffic-sign-detection' # YOLOv8 traffic sign model from HuggingFace (better trained)
3
- confidence_threshold: 0.35 # Minimum confidence for detections
4
 
5
  inference:
6
  box_color: (128, 0, 128) # Purple color for bounding boxes (BGR format)
 
1
  model:
2
+ path: 'VietCat/GTSRB-Model/models/GTSRB.pt' # Path to the YOLO model on Hugging Face Hub (will be downloaded automatically)
3
+ confidence_threshold: 0.15 # Minimum confidence for detections
4
 
5
  inference:
6
  box_color: (128, 0, 128) # Purple color for bounding boxes (BGR format)
model.py CHANGED
@@ -22,24 +22,18 @@ class TrafficSignDetector:
22
  # Load model from path
23
  model_path = config['model']['path']
24
 
25
- # If it's a HuggingFace path, use hf_hub_download
26
- if '/' in model_path and not model_path.endswith('.pt'):
27
- # HuggingFace model repo
28
- parts = model_path.split('/')
29
- if len(parts) == 2:
30
- repo_id = model_path
31
- # Try to download best.pt or runs/detect/.../weights/best.pt
32
- try:
33
- local_model_path = hf_hub_download(repo_id=repo_id, filename='best.pt')
34
- except:
35
- local_model_path = hf_hub_download(repo_id=repo_id, filename='runs/detect/train/weights/best.pt')
36
- self.model = YOLO(local_model_path)
37
- else:
38
- # Full path with filename
39
- repo_id = '/'.join(parts[:-1])
40
- file_path = parts[-1]
41
  local_model_path = hf_hub_download(repo_id=repo_id, filename=file_path)
42
  self.model = YOLO(local_model_path)
 
 
 
43
  else:
44
  # Local path or direct model path
45
  self.model = YOLO(model_path)
 
22
  # Load model from path
23
  model_path = config['model']['path']
24
 
25
+ # Handle HuggingFace paths
26
+ if '/' in model_path:
27
+ if model_path.endswith('.pt'):
28
+ # Full path with filename (e.g., VietCat/GTSRB-Model/models/GTSRB.pt)
29
+ parts = model_path.rsplit('/', 1)
30
+ repo_id = parts[0]
31
+ file_path = parts[1]
 
 
 
 
 
 
 
 
 
32
  local_model_path = hf_hub_download(repo_id=repo_id, filename=file_path)
33
  self.model = YOLO(local_model_path)
34
+ else:
35
+ # Just repo name (e.g., user/model-name)
36
+ self.model = YOLO(model_path)
37
  else:
38
  # Local path or direct model path
39
  self.model = YOLO(model_path)