Spaces:
Sleeping
Sleeping
| import re | |
def generate_script_v8(dataset_code, task, model_size, epochs, batch_size):
    """Build a standalone YOLOv8 training script from a Roboflow snippet.

    Args:
        dataset_code: Text of a Roboflow download snippet. It must contain
            ``api_key="..."``, ``workspace("...")``, ``project("...")`` and
            ``version(<digits>)``.
        task: ``"Segmentation"`` selects the ``-seg`` checkpoint; any other
            value selects the plain detection checkpoint.
        model_size: Size letter inserted into the checkpoint name
            (``yolov8{model_size}...``).
        epochs: Value written into the generated ``model.train`` call.
        batch_size: Value written into the generated ``model.train`` call.

    Returns:
        The generated script as a string, or a message starting with
        ``"Error:"`` when any of the four fields cannot be extracted.
    """
    # Extract the necessary information from the dataset code. Every pattern
    # requires a non-empty value so an empty api_key is reported as an error
    # instead of being written into the script.
    api_key_match = re.search(r'api_key="([^"]+)"', dataset_code)
    workspace_match = re.search(r'workspace\("([^"]+)"\)', dataset_code)
    project_name_match = re.search(r'project\("([^"]+)"\)', dataset_code)
    version_number_match = re.search(r'version\((\d+)\)', dataset_code)

    matches = (api_key_match, workspace_match, project_name_match, version_number_match)
    if not all(matches):
        return "Error: Could not extract necessary information from the dataset code."

    api_key = api_key_match.group(1)
    workspace = workspace_match.group(1)
    project_name = project_name_match.group(1)
    version_number = int(version_number_match.group(1))

    # Detection checkpoints carry no task suffix ("yolov8n.pt"); "-cls" would
    # load a classification model, which is not one of the offered tasks.
    model_suffix = "-seg" if task == "Segmentation" else ""
    model_name = f"yolov8{model_size}{model_suffix}.pt"

    # Generate the script. Doubled braces are literal braces in the output.
    script = f"""
import yaml
from ultralytics import YOLO
from roboflow import Roboflow
import logging
import re
import threading
import time
from io import StringIO

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def auto_train():
    log_stream = StringIO()
    log_handler = logging.StreamHandler(log_stream)
    log_handler.setLevel(logging.INFO)
    logger.addHandler(log_handler)

    # Set to make the log-printing thread exit; also set in ``finally`` so the
    # thread is stopped when training raises.
    stop_logs = threading.Event()
    log_thread = None

    try:
        api_key = "{api_key}"
        workspace = "{workspace}"
        project_name = "{project_name}"
        version_number = {version_number}

        # Load the Roboflow dataset
        rf = Roboflow(api_key=api_key)
        project = rf.workspace(workspace).project(project_name)
        version = project.version(version_number)
        dataset = version.download("yolov8")

        # Modify the data structure
        yaml_file_path = f'{{dataset.location}}/data.yaml'
        with open(yaml_file_path, 'r') as file:
            data = yaml.safe_load(file)
        data['val'] = '../valid/images'
        data['test'] = '../test/images'
        data['train'] = '../train/images'
        with open(yaml_file_path, 'w') as file:
            yaml.safe_dump(data, file)

        # Load the model for the selected size and task
        model_name = "{model_name}"
        model = YOLO(model_name)
        model.info()

        # Print whatever has been logged since the previous check
        def update_logs():
            printed = 0
            while not stop_logs.wait(1):
                text = log_stream.getvalue()
                if len(text) > printed:
                    print(text[printed:], end="")
                    printed = len(text)

        # Start a thread to print logs while training runs
        log_thread = threading.Thread(target=update_logs)
        log_thread.start()

        model.train(data=yaml_file_path, epochs={epochs}, imgsz=640, batch={batch_size})

        # Stop the log thread before printing the summary
        stop_logs.set()
        log_thread.join()

        print("Results Directory:", model.trainer.save_dir)
        print("Final Training Logs:", log_stream.getvalue())
    except Exception as e:
        logger.error(f"An error occurred: {{e}}")
        print(f"Error: {{e}}")
        print(log_stream.getvalue())
    finally:
        stop_logs.set()
        if log_thread is not None:
            log_thread.join()
        logger.removeHandler(log_handler)


if __name__ == "__main__":
    auto_train()
"""
    return script
def generate_script_v9(dataset_code, task, model_size, epochs, batch_size):
    """Build notebook cells that download a Roboflow dataset and train YOLOv9.

    The output uses ``!`` shell lines and ``{{...}}`` notebook interpolation of
    ``dataset.location`` and ``HOME``; ``HOME`` is not defined by the generated
    text and has to exist in the notebook already.

    Args:
        dataset_code: Text of a Roboflow download snippet. It must contain
            ``api_key="..."``, ``workspace("...")``, ``project("...")`` and
            ``version(<digits>)``.
        task: ``"Segmentation"`` selects segmentation weights, config folder
            and training script; any other value selects detection.
        model_size: Size letter inserted into the weights/config name.
        epochs: Value passed to ``--epochs``.
        batch_size: Value passed to ``--batch``.

    Returns:
        The generated cells as a string, or a message starting with
        ``"Error:"`` when any of the four fields cannot be extracted.
    """
    # Extract the necessary information from the dataset code. Every pattern
    # requires a non-empty value so an empty api_key is reported as an error.
    api_key_match = re.search(r'api_key="([^"]+)"', dataset_code)
    workspace_match = re.search(r'workspace\("([^"]+)"\)', dataset_code)
    project_name_match = re.search(r'project\("([^"]+)"\)', dataset_code)
    version_number_match = re.search(r'version\((\d+)\)', dataset_code)

    matches = (api_key_match, workspace_match, project_name_match, version_number_match)
    if not all(matches):
        return "Error: Could not extract necessary information from the dataset code."

    api_key = api_key_match.group(1)
    workspace = workspace_match.group(1)
    project_name = project_name_match.group(1)
    version_number = int(version_number_match.group(1))

    # Pick weights, training script and config folder for the selected task.
    if task == "Segmentation":
        model_name = "gelan-c-seg.pt" if model_size == "c" else f"yolov9-{model_size}-seg.pt"
        # NOTE(review): the YOLOv9 README trains segmentation models with
        # segment/train.py and configs under models/segment/, and its command
        # does not pass --min-items -- confirm against the checked-out repo.
        train_script = "segment/train.py"
        cfg_dir = "models/segment"
        min_items_flag = ""
    else:
        model_name = f"yolov9-{model_size}.pt"
        # NOTE(review): the YOLOv9 README trains yolov9-* configs with
        # train_dual.py -- confirm that train.py is the intended entry point.
        train_script = "train.py"
        cfg_dir = "models/detect"
        min_items_flag = " --min-items 0"

    # The config file shares the weights' base name (text before the first dot).
    cfg_name = model_name.split(".")[0]

    # Generate the script. "\\\\"-style escapes below become a single trailing
    # backslash (shell line continuation) in the output.
    script = f"""
!pip install roboflow
from roboflow import Roboflow
rf = Roboflow(api_key="{api_key}")
project = rf.workspace("{workspace}").project("{project_name}")
version = project.version({version_number})
dataset = version.download("yolov9")
!python {train_script} \\
--batch {batch_size} --epochs {epochs} --img 640 --device 0{min_items_flag} --close-mosaic 15 \\
--data {{dataset.location}}/data.yaml \\
--weights {{HOME}}/weights/{model_name} \\
--cfg {cfg_dir}/{cfg_name}.yaml \\
--hyp hyp.scratch-high.yaml
"""
    return script
import streamlit as st


def _render_script(script):
    """Show a generator result: failures as an error box, scripts as code."""
    # generate_script_v8 / generate_script_v9 signal failure by returning a
    # message that starts with "Error:" instead of a script.
    if script.startswith("Error:"):
        st.error(script)
    else:
        st.code(script, language="python")


# Page header
st.title("Auto Train Script Generator")
st.write("Generate a YOLO training script using a Roboflow dataset")

# One tab per generator; widget keys are suffixed per tab so the two forms
# keep independent state.
tab1, tab2 = st.tabs(["YOLOv8", "YOLOv9"])

with tab1:
    st.subheader("YOLOv8 Script Generator")
    dataset_code = st.text_input("Roboflow Dataset Code", key="dataset_code_v8", placeholder="Paste your Roboflow dataset code here")
    task = st.selectbox("Task", ["Object Detection", "Segmentation"], index=0, key="task_v8")
    model_size = st.selectbox("Model Size", ["n", "s", "m", "l", "x"], index=0, key="model_size_v8")
    epochs = st.selectbox("Epochs", [50, 100, 200, 300, 400, 500], index=3, key="epochs_v8")
    batch_size = st.selectbox("Batch Size", [1, 2, 4, 8, 16, 32], index=0, key="batch_size_v8")
    if st.button("Generate YOLOv8 Script"):
        script = generate_script_v8(dataset_code, task, model_size, epochs, batch_size)
        _render_script(script)

with tab2:
    st.subheader("YOLOv9 Script Generator")
    dataset_code = st.text_input("Roboflow Dataset Code", key="dataset_code_v9", placeholder="Paste your Roboflow dataset code here")
    task = st.selectbox("Task", ["Object Detection", "Segmentation"], index=0, key="task_v9")
    model_size = st.selectbox("Model Size", ["t", "s", "m", "c", "e"], index=0, key="model_size_v9")
    epochs = st.selectbox("Epochs", [50, 100, 200, 300, 400, 500], index=3, key="epochs_v9")
    batch_size = st.selectbox("Batch Size", [1, 2, 4, 8, 16, 32], index=0, key="batch_size_v9")
    if st.button("Generate YOLOv9 Script"):
        script = generate_script_v9(dataset_code, task, model_size, epochs, batch_size)
        _render_script(script)