Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import json | |
| import os | |
| from pathlib import Path | |
| from PIL import Image | |
| import shutil | |
| from ultralytics import YOLO | |
| import requests | |
# Root directory for model weights; load_models() keeps each model in its
# own sub-directory here, as <MODELS_DIR>/<model_name>/<model_name>.pt.
MODELS_DIR = "models"
# JSON file listing the models; load_models() reads it as a list of dicts.
MODELS_INFO_FILE = "models_info.json"
# Directory where predict_image() stores the uploaded input image.
TEMP_DIR = "temp"
# Directory where predict_image() places the annotated output image.
OUTPUT_DIR = "outputs"
def download_file(url, dest_path):
    """
    Download a file from a URL to the destination path.

    The payload is streamed into a sibling ``<dest_path>.part`` file and only
    moved onto ``dest_path`` once the transfer has completed, so a failed or
    interrupted download never leaves a truncated file at ``dest_path``.

    Args:
        url (str): The URL to download from.
        dest_path (str): The local path to save the file.

    Returns:
        bool: True if download succeeded, False otherwise.
    """
    partial_path = f"{dest_path}.part"
    try:
        # (connect, read) timeouts: without them a stalled server would block
        # the caller indefinitely.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Rename only after the whole body was written successfully.
        os.replace(partial_path, dest_path)
        print(f"Downloaded {url} to {dest_path}.")
        return True
    except Exception as e:
        # Best-effort cleanup of the incomplete file; the function's contract
        # is to report failure via the return value, never to raise.
        try:
            if os.path.exists(partial_path):
                os.remove(partial_path)
        except OSError:
            pass
        print(f"Failed to download {url}. Error: {e}")
        return False
def load_models(models_dir=MODELS_DIR, info_file=MODELS_INFO_FILE):
    """
    Load YOLO models and their information from the specified directory and JSON file.
    Downloads models if they are not already present.

    Each entry of the JSON list must provide ``model_name``; ``display_name``
    is optional and falls back to ``model_name``. ``download_url`` is only
    consulted when the weights file is missing locally, so models that are
    already on disk load even if no URL is configured.

    Args:
        models_dir (str): Path to the models directory.
        info_file (str): Path to the JSON file containing model info.

    Returns:
        dict: Maps each successfully loaded ``model_name`` to a dict with the
            keys ``display_name``, ``model`` and ``info``. Models that fail to
            download or load are skipped with a printed message.

    Raises:
        OSError: If ``info_file`` cannot be opened.
        json.JSONDecodeError: If ``info_file`` is not valid JSON.
        KeyError: If an entry has no ``model_name``.
    """
    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(info_file, 'r', encoding='utf-8') as f:
        models_info = json.load(f)

    models = {}
    for model_info in models_info:
        model_name = model_info['model_name']
        display_name = model_info.get('display_name', model_name)
        model_dir = os.path.join(models_dir, model_name)
        os.makedirs(model_dir, exist_ok=True)
        model_path = os.path.join(model_dir, f"{model_name}.pt")

        if not os.path.isfile(model_path):
            # The URL is needed only for models that are not cached yet.
            download_url = model_info.get('download_url')
            if not download_url:
                print(f"Skipping model '{display_name}': not found locally and no download URL configured.")
                continue
            print(f"Model '{display_name}' not found locally. Downloading from {download_url}...")
            if not download_file(download_url, model_path):
                print(f"Skipping model '{display_name}' due to download failure.")
                continue

        try:
            model = YOLO(model_path)
        except Exception as e:
            # One broken weights file must not prevent the others from loading.
            print(f"Error loading model '{display_name}': {e}")
            continue

        models[model_name] = {
            'display_name': display_name,
            'model': model,
            'info': model_info
        }
        print(f"Loaded model '{display_name}' from '{model_path}'.")
    return models
def get_model_info(model_info):
    """
    Build the Markdown summary shown for a model.

    Args:
        model_info (dict): The model's information dictionary.

    Returns:
        str: Markdown text with one section per attribute, sections separated
            by blank lines. Missing scalar attributes are shown as ``N/A``;
            missing collections produce an empty section body.
    """
    # (label, key) pairs for the single-value attributes, in display order.
    scalar_fields = (
        ("Architecture", 'architecture'),
        ("Training Epochs", 'training_epochs'),
        ("Batch Size", 'batch_size'),
        ("Optimizer", 'optimizer'),
        ("Learning Rate", 'learning_rate'),
        ("Data Augmentation Level", 'data_augmentation_level'),
        ("mAP@0.5", 'mAP_score'),
        ("Number of Images Trained On", 'num_images'),
    )

    sections = [f"**{model_info.get('display_name', 'Model Name')}**"]
    for label, key in scalar_fields:
        sections.append(f"**{label}:** {model_info.get(key, 'N/A')}")

    # Multi-line sections: a header line followed by one line per item.
    id_lines = [
        f"{class_id}: {class_name}"
        for class_id, class_name in model_info.get('class_ids', {}).items()
    ]
    dataset_lines = [f"- {name}" for name in model_info.get('datasets_used', [])]
    count_lines = [
        f"{class_name}: {total}"
        for class_name, total in model_info.get('class_image_counts', {}).items()
    ]
    sections.append("**Class IDs:**\n" + "\n".join(id_lines))
    sections.append("**Datasets Used:**\n" + "\n".join(dataset_lines))
    sections.append("**Class Image Counts:**\n" + "\n".join(count_lines))

    return "\n\n".join(sections)
def predict_image(model_name, image, confidence, models):
    """
    Perform prediction on an uploaded image using the selected YOLO model.

    The upload is written to ``TEMP_DIR`` as a JPEG, run through the model,
    and the annotated result is written to ``OUTPUT_DIR``.

    Args:
        model_name (str): The name of the selected model.
        image (PIL.Image.Image): The uploaded image, or None if nothing was uploaded.
        confidence (float): The confidence threshold for detections.
        models (dict): The dictionary containing models and their info.

    Returns:
        tuple: A status message, the processed image, and the path to the
            output image. On any failure the last two items are None and the
            status message describes the problem; no exception is raised.
    """
    model = models.get(model_name, {}).get('model')
    if model is None:
        return "Error: Model not found.", None, None
    if image is None:
        # Give a clear message instead of an AttributeError from image.save().
        return "❌ Please upload an image before running prediction.", None, None

    try:
        os.makedirs(TEMP_DIR, exist_ok=True)
        os.makedirs(OUTPUT_DIR, exist_ok=True)

        input_image_path = os.path.join(TEMP_DIR, f"{model_name}_input_image.jpg")
        # JPEG cannot store alpha or palette modes (e.g. RGBA PNG uploads).
        if image.mode != "RGB":
            image = image.convert("RGB")
        image.save(input_image_path)

        results = model(input_image_path, conf=confidence)

        # Write the annotated image straight to its destination rather than
        # searching runs/detect for the most recently modified folder, which
        # is racy and depends on the working directory.
        final_output_path = os.path.join(OUTPUT_DIR, f"{model_name}_output_image.jpg")
        results[0].save(filename=final_output_path)

        # Copy the pixels into memory so no file handle stays open.
        with Image.open(final_output_path) as saved:
            output_image = saved.copy()
        return "✅ Prediction completed successfully.", output_image, final_output_path
    except Exception as e:
        return f"❌ Error during prediction: {str(e)}", None, None
def main():
    """
    Load the configured models and serve the Gradio model-testing UI.

    Returns early, after printing a hint, when no model could be loaded.
    """
    models = load_models()
    if not models:
        print("No models loaded. Please check your models_info.json and model URLs.")
        return

    # The dropdown lists display names; the callbacks map a selection back to
    # the key used in `models`.
    display_to_name = {
        entry['display_name']: name for name, entry in models.items()
    }

    def update_model_info(selected_display_name):
        # Text for the info panel next to the dropdown.
        if not selected_display_name:
            return "Please select a model."
        model_name = display_to_name.get(selected_display_name)
        if not model_name:
            return "Model information not available."
        return get_model_info(models[model_name]['info'])

    def process_image(selected_display_name, image, confidence):
        # Click handler: resolve the selected model and run the prediction.
        if not selected_display_name:
            return "❌ Please select a model.", None, None
        return predict_image(
            display_to_name.get(selected_display_name), image, confidence, models
        )

    with gr.Blocks() as demo:
        gr.Markdown("# 🧪 YOLOv11 Model Tester")
        gr.Markdown(
            """
            Upload images to test different YOLOv11 models. Select a model from the dropdown to see its details.
            """
        )

        with gr.Row():
            model_dropdown = gr.Dropdown(
                choices=[entry['display_name'] for entry in models.values()],
                label="Select Model",
                value=None
            )
            model_info = gr.Markdown("**Model Information will appear here.**")

        model_dropdown.change(
            fn=update_model_info,
            inputs=model_dropdown,
            outputs=model_info
        )

        with gr.Row():
            confidence_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.25,
                label="Confidence Threshold",
                info="Adjust the minimum confidence required for detections to be displayed."
            )

        with gr.Tab("🖼️ Image"):
            with gr.Column():
                image_input = gr.Image(
                    type='pil',
                    label="Upload Image for Prediction"
                )
                image_predict_btn = gr.Button("🔍 Predict on Image")
                image_status = gr.Markdown("**Status will appear here.**")
                image_output = gr.Image(label="Predicted Image")
                image_download_btn = gr.File(label="⬇️ Download Predicted Image")

            image_predict_btn.click(
                fn=process_image,
                inputs=[model_dropdown, image_input, confidence_slider],
                outputs=[image_status, image_output, image_download_btn]
            )

        gr.Markdown(
            """
            ---
            **Note:** Models are downloaded from GitHub upon first use. Ensure that you have a stable internet connection and sufficient storage space.
            """
        )

    demo.launch()
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()