# sparsh007's picture
# Update app.py
# cc112f3 verified
import subprocess
import sys
import os
import torch
import gradio as gr
from PIL import Image
import zipfile
# Import YOLO from ultralytics; if the package is missing, pip-install it
# with the running interpreter and retry the import once.
try:
    from ultralytics import YOLO
except ImportError:
    install_command = [sys.executable, "-m", "pip", "install", "ultralytics"]
    subprocess.check_call(install_command)
    from ultralytics import YOLO
# Trained YOLOv8 weights; the path is resolved relative to the working directory.
model = YOLO('best-3.pt')

# Module-level caches that back the image carousel and the ZIP export.
# Index i in each list refers to the i-th uploaded image.
processed_images, label_contents, processed_image_paths = [], [], []

# Scratch directory that receives the rendered images and label files.
TEMP_DIR = "temp_processed"
os.makedirs(TEMP_DIR, exist_ok=True)
def process_images(files):
    """Run detection on every uploaded image and cache the results.

    For each input file the annotated image and a YOLO-format label string
    are appended to the module-level ``processed_images`` / ``label_contents``
    lists, and both are also written to ``TEMP_DIR`` as
    ``processed_image_<i>.png`` and ``annotation_<i>.txt``. Any files left in
    ``TEMP_DIR`` from a previous call are removed first.

    Args:
        files: Iterable of image file paths, or ``None``/empty when the
            upload widget has been cleared.

    Returns:
        A ``(image, label_text, index)`` tuple for the first processed image,
        or ``(None, "No images found.", 0)`` when nothing was processed.
    """
    global processed_images, label_contents, processed_image_paths
    processed_images = []
    label_contents = []
    processed_image_paths = []

    # Clear the temp directory. Only plain files are removed so that an
    # unexpected subdirectory does not make os.remove() raise.
    for entry in os.listdir(TEMP_DIR):
        entry_path = os.path.join(TEMP_DIR, entry)
        if os.path.isfile(entry_path):
            os.remove(entry_path)

    # The upload component reports None once its selection is cleared.
    for i, file_path in enumerate(files or []):
        # Keep the file handle open only for the duration of inference.
        with Image.open(file_path) as img:
            # Calling a YOLOv8 model returns a list of Results objects,
            # one per input image.
            results = model(img)

        # Results.plot() yields an annotated BGR array; reverse the channel
        # axis to get the RGB order PIL expects.
        annotated_bgr = results[0].plot()
        output_image = Image.fromarray(annotated_bgr[..., ::-1])
        processed_images.append(output_image)

        # Generate YOLO-format labels: "<class> <xc> <yc> <w> <h>" with
        # coordinates normalised to the image size. The class id comes from
        # boxes.cls; the xywh rows only carry the four box values.
        label_lines = []
        for result in results:
            boxes = result.boxes
            if boxes is None:
                continue
            for class_id, (x_center, y_center, width, height) in zip(
                boxes.cls.tolist(), boxes.xywhn.tolist()
            ):
                label_lines.append(
                    f"{int(class_id)} {x_center:.6f} {y_center:.6f} "
                    f"{width:.6f} {height:.6f}\n"
                )
        label_content = "".join(label_lines)
        label_contents.append(label_content)

        # Save the image to the temp folder
        image_path = os.path.join(TEMP_DIR, f"processed_image_{i}.png")
        output_image.save(image_path)
        processed_image_paths.append(image_path)

        # Save the label content to a text file
        label_path = os.path.join(TEMP_DIR, f"annotation_{i}.txt")
        with open(label_path, "w") as label_file:
            label_file.write(label_content)

    # Return the first image and its labels, starting the carousel at index 0.
    if processed_images:
        return processed_images[0], label_contents[0], 0
    return None, "No images found.", 0
def create_zip():
    """Bundle every processed image and its label file into one ZIP archive.

    The archive is written inside ``TEMP_DIR``; an archive left over from a
    previous call is deleted first. Entries are stored flat, under their
    base file names.

    Returns:
        The path of the created ZIP file, as a string.
    """
    archive_path = os.path.join(TEMP_DIR, "processed_images_annotations.zip")

    # Drop any stale archive before writing a new one.
    if os.path.exists(archive_path):
        os.remove(archive_path)

    with zipfile.ZipFile(archive_path, 'w') as archive:
        for img_path in processed_image_paths:
            img_name = os.path.basename(img_path)
            archive.write(img_path, img_name)

            # The text after the last underscore of the image's stem is the
            # index that names the matching annotation file.
            stem = os.path.splitext(img_name)[0]
            idx = stem.rpartition('_')[2]
            txt_name = f"annotation_{idx}.txt"
            archive.write(os.path.join(TEMP_DIR, txt_name), txt_name)

    return archive_path
def next_image(index):
    """Step the carousel forward by one image, wrapping after the last.

    Args:
        index: Position of the image currently shown.

    Returns:
        ``(image, label_text, new_index)``; when no images have been
        processed, ``(None, "No images processed.", index)`` with the index
        unchanged.
    """
    if not processed_images:
        return None, "No images processed.", index

    following = (index + 1) % len(processed_images)
    return processed_images[following], label_contents[following], following
def prev_image(index):
    """Step the carousel back by one image, wrapping before the first.

    Args:
        index: Position of the image currently shown.

    Returns:
        ``(image, label_text, new_index)``; when no images have been
        processed, ``(None, "No images processed.", index)`` with the index
        unchanged.
    """
    if not processed_images:
        return None, "No images processed.", index

    # Python's modulo is non-negative here, so -1 wraps to the last image.
    preceding = (index - 1) % len(processed_images)
    return processed_images[preceding], label_contents[preceding], preceding
# Gradio interface: upload widget, result viewer, carousel controls and ZIP export.
with gr.Blocks() as interface:
    # Components are created in the order they should appear on the page.
    file_input = gr.Files(label="Upload multiple image files", type="filepath")
    image_display = gr.Image(label="Processed Image")
    label_display = gr.Textbox(label="Label File Content")
    prev_button = gr.Button("Previous Image")
    next_button = gr.Button("Next Image")
    # Hidden state holding the index of the image currently shown.
    current_index = gr.State(0)
    download_button = gr.Button("Prepare and Download All")
    download_file = gr.File()

    # Uploading and both navigation buttons refresh the same three outputs.
    carousel_outputs = [image_display, label_display, current_index]

    file_input.change(process_images, inputs=file_input, outputs=carousel_outputs)
    next_button.click(next_image, inputs=current_index, outputs=carousel_outputs)
    prev_button.click(prev_image, inputs=current_index, outputs=carousel_outputs)

    def prepare_download():
        """Build the ZIP archive and return its path for the file widget."""
        return create_zip()

    download_button.click(prepare_download, outputs=download_file)

# Launch the interface
interface.launch(share=True)