# Hugging Face Space app (Space status when this copy was captured: Sleeping)
| import subprocess | |
| import sys | |
| import os | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| import zipfile | |
# Import ultralytics, installing it at runtime if it is missing.
# NOTE(review): declaring it in requirements.txt is preferable to installing here.
try:
    from ultralytics import YOLO
except ImportError:
    import importlib

    subprocess.check_call([sys.executable, "-m", "pip", "install", "ultralytics"])
    # The import system caches directory contents; a package installed after
    # the interpreter started may not be found until those caches are cleared.
    importlib.invalidate_caches()
    from ultralytics import YOLO
# Trained YOLOv8 weights (relative path; make sure the file is present).
model = YOLO('best-3.pt')

# Module-level carousel state: annotated PIL images, their label text, and
# the paths of the saved PNGs. process_images() rebuilds all three.
processed_images, label_contents, processed_image_paths = [], [], []

# Scratch folder for annotated images, label files and the ZIP archive.
TEMP_DIR = "temp_processed"
os.makedirs(TEMP_DIR, exist_ok=True)
def _clear_temp_dir():
    """Remove the files left in TEMP_DIR by a previous run."""
    for entry in os.listdir(TEMP_DIR):
        entry_path = os.path.join(TEMP_DIR, entry)
        # os.remove cannot delete directories; skip any sub-folder instead of
        # aborting the whole upload with an exception.
        if os.path.isfile(entry_path):
            os.remove(entry_path)


def _annotated_image(result):
    """Return *result*'s image with its detections drawn on, as an RGB PIL image."""
    # Results.plot() returns a BGR numpy array; reverse the channel axis for PIL.
    return Image.fromarray(result.plot()[..., ::-1])


def _yolo_label_text(result):
    """Return YOLO-format label lines for one ultralytics ``Results`` object.

    Each line is ``class_id x_center y_center width height``, using the
    image-size-normalized box coordinates (``Boxes.xywhn``) that YOLO training
    data expects. Returns "" when there are no detections.
    """
    boxes = result.boxes
    if boxes is None:  # e.g. non-detection models produce no boxes
        return ""
    return "".join(
        f"{int(class_id)} {x:.6f} {y:.6f} {w:.6f} {h:.6f}\n"
        for class_id, (x, y, w, h) in zip(boxes.cls.tolist(), boxes.xywhn.tolist())
    )


def process_images(files):
    """Run the YOLO model on every uploaded image and cache the results.

    For input number ``i``:
    - An annotated copy is saved as ``TEMP_DIR/processed_image_<i>.png``.
    - Its YOLO labels are saved as ``TEMP_DIR/annotation_<i>.txt``.

    Both are also kept in the module-level lists read by
    ``next_image``/``prev_image`` and ``create_zip``.

    NOTE(review): that state is global, so concurrent users of the app
    overwrite each other's results; keeping it in ``gr.State`` would avoid it.

    Args:
        files: Iterable of image file paths (from ``gr.Files(type="filepath")``),
            or None, which is treated as no files.

    Returns:
        tuple: ``(first_image, first_labels, 0)``, or
        ``(None, "No images found.", 0)`` when nothing was processed.
    """
    global processed_images, label_contents, processed_image_paths
    processed_images = []
    label_contents = []
    processed_image_paths = []

    _clear_temp_dir()

    for i, file_path in enumerate(files or []):
        with Image.open(file_path) as img:
            # model(...) returns a list with one Results object per input image.
            result = model(img)[0]

        output_image = _annotated_image(result)
        label_content = _yolo_label_text(result)
        processed_images.append(output_image)
        label_contents.append(label_content)

        # File names carry the index i; create_zip relies on this pairing.
        image_path = os.path.join(TEMP_DIR, f"processed_image_{i}.png")
        output_image.save(image_path)
        processed_image_paths.append(image_path)

        label_path = os.path.join(TEMP_DIR, f"annotation_{i}.txt")
        with open(label_path, "w", encoding="utf-8") as label_file:
            label_file.write(label_content)

    if processed_images:
        return processed_images[0], label_contents[0], 0  # start at index 0
    return None, "No images found.", 0
def create_zip():
    """Bundle every processed image and its annotation file into one ZIP.

    Each ``processed_image_<i>.png`` in ``processed_image_paths`` is stored
    next to the ``annotation_<i>.txt`` file from ``TEMP_DIR``.

    Returns:
        str: Path of ``processed_images_annotations.zip`` inside ``TEMP_DIR``.
    """
    archive_path = os.path.join(TEMP_DIR, "processed_images_annotations.zip")
    # Discard an archive left over from an earlier download request.
    if os.path.exists(archive_path):
        os.remove(archive_path)

    with zipfile.ZipFile(archive_path, 'w') as archive:
        for img_path in processed_image_paths:
            img_name = os.path.basename(img_path)
            archive.write(img_path, img_name)
            # The trailing "_<i>" of the image stem identifies its label file.
            suffix = os.path.splitext(img_name)[0].rsplit('_', 1)[-1]
            ann_name = f"annotation_{suffix}.txt"
            archive.write(os.path.join(TEMP_DIR, ann_name), ann_name)

    return archive_path
def _carousel_frame(index, step):
    """Shift *index* by *step* with wrap-around and return the frame to display.

    Returns ``(image, label_text, new_index)``. When nothing has been processed,
    returns ``(None, "No images processed.", index)`` with the index unchanged.
    """
    if not processed_images:
        return None, "No images processed.", index
    index = (index + step) % len(processed_images)
    return processed_images[index], label_contents[index], index


def next_image(index):
    """Show the processed image after *index*, wrapping around to the first."""
    return _carousel_frame(index, 1)


def prev_image(index):
    """Show the processed image before *index*, wrapping around to the last."""
    return _carousel_frame(index, -1)
# Gradio UI: upload images -> browse annotated results -> download a ZIP.
with gr.Blocks() as interface:
    # Components are created in the order they should appear on the page.
    file_input = gr.Files(label="Upload multiple image files", type="filepath")
    image_display = gr.Image(label="Processed Image")
    label_display = gr.Textbox(label="Label File Content")

    prev_button = gr.Button("Previous Image")
    next_button = gr.Button("Next Image")

    # Hidden state holding the index of the image currently shown.
    current_index = gr.State(0)

    download_button = gr.Button("Prepare and Download All")
    download_file = gr.File()

    # Uploading and both navigation buttons refresh the same three outputs.
    carousel_outputs = [image_display, label_display, current_index]

    file_input.change(process_images, inputs=file_input, outputs=carousel_outputs)
    for button, handler in ((next_button, next_image), (prev_button, prev_image)):
        button.click(handler, inputs=current_index, outputs=carousel_outputs)

    def prepare_download():
        """Build the ZIP archive and return its path for the File component."""
        return create_zip()

    download_button.click(prepare_download, outputs=download_file)

# Start the web server.
interface.launch(share=True)