dgarrett-synaptics commited on
Commit
8d8e9d5
·
verified ·
1 Parent(s): f5c7ad3

Create training.py

Browse files
Files changed (1) hide show
  1. training.py +118 -0
training.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import sys
4
+ import os
5
+ from datasets import load_dataset
6
+ from huggingface_hub import snapshot_download
7
+ import tempfile
8
+
9
+ temp_dir = tempfile.TemporaryDirectory()
10
+ print(f'Creating {temp_dir.name}')
11
+
12
+ try:
13
+ from synet.backends import get_backend
14
+ get_backend('ultralytics').patch()
15
+ from synet.backends import get_backend
16
+ from ultralytics import YOLO
17
+ except ImportError:
18
+ import subprocess
19
+ print('Installing synet package')
20
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "./synet_package[ultra]"])
21
+ # Import the file
22
+ from synet.backends import get_backend
23
+ from synet.backends import get_backend
24
+ from ultralytics import YOLO
25
+
26
+
27
+ # Setup the backend processing
28
+ backend = get_backend('ultralytics')
29
+ backend.patch()
30
+
31
+ dataset_cfg = None
32
+
33
+
34
+ @spaces.GPU
35
+ def greet(name):
36
+ return "Hello " + name + "!!"
37
+
38
+ def get_dataset():
39
+ global dataset_cfg
40
+
41
+ dataset_name = 'Ultralytics/COCO8'
42
+
43
+ dataset_path = f'{temp_dir.name}/COCO8'
44
+ print(f'Writing to dataset {dataset_path}')
45
+ snapshot_download(repo_id=dataset_name, repo_type='dataset', local_dir=dataset_path)
46
+
47
+ #dataset = load_dataset(dataset_name)
48
+ #print(dataset)
49
+ #dataset.save_to_disk(dataset_path)
50
+
51
+ files = os.listdir(dataset_path)
52
+ local_files = 'local_files in {dataset_path}: '
53
+ for file in files:
54
+ local_files += f'{file} '
55
+ print(local_files)
56
+
57
+
58
+
59
+ dataset_cfg = f'{dataset_path}/dataset.yaml'
60
+ print(f'Writing to dataset {dataset_cfg}')
61
+ with open(dataset_cfg, 'r') as fp:
62
+ contents = fp.read()
63
+ print(contents)
64
+
65
+ return f"Loading the dataset in {temp_dir.name}"
66
+
67
+ @spaces.GPU
68
+ def run_training():
69
+
70
+ model_cfg = 'synet_package/synet/zoo/ultralytics/sabre-detect-vga.yaml'
71
+ model = YOLO(model=model_cfg)
72
+ print(f'Loading model_cfg {model_cfg}')
73
+ print(model)
74
+ image_size = (480,640)
75
+
76
+ # Run the initial training
77
+ project_path = f'{temp_dir.name}/runs'
78
+ print(f'Run the training in {project_path}')
79
+ model.train(data=dataset_cfg, project=project_path, name='example_train')
80
+
81
+ # PRint teh files
82
+ files = os.listdir(project_path)
83
+ local_files = 'local_files: '
84
+ for file in files:
85
+ local_files += f'{file} '
86
+
87
+ return f"Done with training: {local_files}"
88
+ #get_tflite(backend, image_size, 'runs/example_train/weights/best.pt',
89
+ # 'coco.yaml', 500, 3, {})
90
+
91
+ # Run the ultralytics
92
+ #def run_ultralytics():
93
+ #
94
+ # files = os.listdir()
95
+ # local_files = 'local_files: '
96
+ # for file in files:
97
+ # local_files += f'{file} '
98
+ #
99
+ # return local_files
100
+
101
+ with gr.Blocks() as demo:
102
+ text1 = gr.Markdown("Starting to test SyNet")
103
+ text2 = gr.Markdown("")
104
+
105
+ load_btn = gr.Button("Load dataset")
106
+ load_text = gr.Markdown("")
107
+
108
+ train_btn = gr.Button("Train")
109
+ train_text = gr.Markdown("")
110
+
111
+ load_btn.click(get_dataset, inputs=None, outputs=[load_text])
112
+ train_btn.click(run_training, inputs=None, outputs=[train_text])
113
+
114
+ #demo.load(run_ultralytics, inputs=None, outputs=[text2])
115
+
116
+ if __name__ == "__main__":
117
+ #demo.launch(share=True)
118
+ demo.launch()