shaheerawan3 commited on
Commit
e5c0b33
·
verified ·
1 Parent(s): 6b66502

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +27 -34
train.py CHANGED
@@ -1,19 +1,14 @@
 
1
  import yaml
2
  from ultralytics import YOLO
3
- import os
4
- import gdown
5
-
6
- def download_pretrained_weights():
7
- """Download pre-trained YOLOv8 weights"""
8
- url = 'https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x.pt'
9
- gdown.download(url, 'yolov8x.pt', quiet=False)
10
 
11
  def create_dataset_yaml():
12
- """Create YAML file for dataset configuration"""
13
  data_yaml = {
14
- 'path': 'dataset', # dataset root dir
15
- 'train': 'images/train', # train images
16
- 'val': 'images/val', # val images
17
  'names': {
18
  0: 'Commercial Airliner',
19
  1: 'Military Fighter',
@@ -24,30 +19,28 @@ def create_dataset_yaml():
24
  6: 'Unknown'
25
  }
26
  }
27
-
 
 
 
 
28
  with open('dataset.yaml', 'w') as f:
29
  yaml.dump(data_yaml, f)
30
 
31
  def train_model():
32
- """Train the YOLOv8 model"""
33
- # Download pre-trained weights if not exists
34
- if not os.path.exists('yolov8x.pt'):
35
- download_pretrained_weights()
36
-
37
- # Create dataset configuration
38
- create_dataset_yaml()
39
-
40
- # Initialize model
41
- model = YOLO('yolov8x.pt')
42
-
43
- # Train the model
44
- results = model.train(
45
- data='dataset.yaml',
46
- epochs=100,
47
- imgsz=640,
48
- batch=16,
49
- name='aircraft_detection'
50
- )
51
-
52
- # Export the model
53
- model.export(format='onnx')
 
1
+ import os
2
  import yaml
3
  from ultralytics import YOLO
4
+ import streamlit as st
 
 
 
 
 
 
5
 
6
  def create_dataset_yaml():
7
+ """Create dataset.yaml configuration."""
8
  data_yaml = {
9
+ 'path': 'dataset',
10
+ 'train': 'images/train',
11
+ 'val': 'images/val',
12
  'names': {
13
  0: 'Commercial Airliner',
14
  1: 'Military Fighter',
 
19
  6: 'Unknown'
20
  }
21
  }
22
+ os.makedirs('dataset/images/train', exist_ok=True)
23
+ os.makedirs('dataset/images/val', exist_ok=True)
24
+ os.makedirs('dataset/labels/train', exist_ok=True)
25
+ os.makedirs('dataset/labels/val', exist_ok=True)
26
+
27
  with open('dataset.yaml', 'w') as f:
28
  yaml.dump(data_yaml, f)
29
 
30
  def train_model():
31
+ """Train the custom model."""
32
+ try:
33
+ create_dataset_yaml()
34
+ model = YOLO('yolov8n.pt')
35
+ model.train(
36
+ data='dataset.yaml',
37
+ epochs=100,
38
+ imgsz=640,
39
+ batch=16,
40
+ name='aircraft_detection'
41
+ )
42
+ model.export(format='onnx')
43
+ return True
44
+ except Exception as e:
45
+ st.error(f"Error during training: {str(e)}")
46
+ return False