File size: 5,711 Bytes
cc112f3
 
 
8749c44
 
 
 
cc112f3
 
 
 
 
 
 
8749c44
92837aa
 
8749c44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92837aa
8749c44
 
 
 
 
 
92837aa
 
 
 
 
8749c44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import subprocess
import sys
import os
import torch
import gradio as gr
from PIL import Image
import zipfile

# Make sure ultralytics is importable, installing it into this interpreter on first run
try:
    from ultralytics import YOLO
except ImportError:
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "ultralytics"]
    )
    from ultralytics import YOLO

# Trained YOLOv8 weights, loaded once when this module is executed
model = YOLO("best-3.pt")

# Module-level state read by the carousel and download callbacks:
# annotated PIL images, their YOLO label text, and the saved PNG paths.
processed_images, label_contents, processed_image_paths = [], [], []

# Scratch folder holding processed PNGs, label files and the ZIP archive
TEMP_DIR = "temp_processed"
os.makedirs(TEMP_DIR, exist_ok=True)

def _clear_temp_dir():
    """Delete every regular file directly inside TEMP_DIR (subdirectories are skipped)."""
    for name in os.listdir(TEMP_DIR):
        path = os.path.join(TEMP_DIR, name)
        if os.path.isfile(path):
            os.remove(path)


def _render_annotated(results):
    """Return the first Ultralytics result, with its detections drawn, as an RGB PIL image."""
    # Results.plot() returns a BGR numpy array; reverse the channel axis for PIL (RGB).
    return Image.fromarray(results[0].plot()[..., ::-1])


def _format_yolo_labels(results):
    """Serialize detections in YOLO label format, one newline-terminated line per box.

    Each line is ``<class_id> <x_center> <y_center> <width> <height>`` with the
    coordinates normalized by the image size (``Boxes.xywhn``), which is what
    YOLO training labels expect.
    """
    lines = []
    for result in results:
        boxes = result.boxes
        if boxes is None:  # no box output for this result
            continue
        # .tolist() yields plain Python numbers, so no "tensor(...)" text leaks into the file.
        for cls_id, (x_c, y_c, w, h) in zip(boxes.cls.tolist(), boxes.xywhn.tolist()):
            lines.append(f"{int(cls_id)} {x_c:.6f} {y_c:.6f} {w:.6f} {h:.6f}\n")
    return "".join(lines)


# Function to process all uploaded images
def process_images(files):
    """Run the YOLO model on every uploaded image and cache the annotated outputs.

    Side effects: resets the module-level ``processed_images``,
    ``label_contents`` and ``processed_image_paths`` lists, empties TEMP_DIR,
    then writes ``processed_image_{i}.png`` and ``annotation_{i}.txt`` there
    for each input.

    Args:
        files: Sequence of image file paths from the Gradio ``Files``
            component, or None/empty when nothing is uploaded.

    Returns:
        ``(first_image, first_labels, 0)`` to seed the carousel, or
        ``(None, "No images found.", 0)`` when there were no inputs.
    """
    global processed_images, label_contents, processed_image_paths
    processed_images = []
    label_contents = []
    processed_image_paths = []

    _clear_temp_dir()

    # The Files component can deliver None (e.g. after the upload is cleared).
    for i, file_path in enumerate(files or []):
        # Close the source file as soon as inference has consumed the pixels.
        with Image.open(file_path) as img:
            results = model(img)

        output_image = _render_annotated(results)
        label_content = _format_yolo_labels(results)
        processed_images.append(output_image)
        label_contents.append(label_content)

        image_path = os.path.join(TEMP_DIR, f"processed_image_{i}.png")
        output_image.save(image_path)
        processed_image_paths.append(image_path)

        label_path = os.path.join(TEMP_DIR, f"annotation_{i}.txt")
        with open(label_path, "w", encoding="utf-8") as label_file:
            label_file.write(label_content)

    if processed_images:
        return processed_images[0], label_contents[0], 0  # start the carousel at index 0
    return None, "No images found.", 0

# Function to create and return the path to the ZIP file for download
def create_zip():
    """Bundle every processed image and its YOLO annotation file into one ZIP.

    Any previous archive at the same path is deleted first. Each entry
    ``processed_image_<k>.png`` is followed by ``annotation_<k>.txt``, where
    ``<k>`` is recovered from the image's file name.

    Returns:
        str: Path of the ZIP archive inside TEMP_DIR.
    """
    archive_path = os.path.join(TEMP_DIR, "processed_images_annotations.zip")
    if os.path.exists(archive_path):
        os.remove(archive_path)

    with zipfile.ZipFile(archive_path, 'w') as archive:
        for img_path in processed_image_paths:
            img_name = os.path.basename(img_path)
            archive.write(img_path, img_name)
            # The text after the last "_" in the stem is the image's index.
            stem = os.path.splitext(img_name)[0]
            suffix = stem.rsplit('_', 1)[-1]
            ann_name = f"annotation_{suffix}.txt"
            archive.write(os.path.join(TEMP_DIR, ann_name), ann_name)

    return archive_path

# Function to navigate through images
def next_image(index):
    """Advance the carousel by one, wrapping from the last image back to the first.

    Args:
        index: Position currently shown (held in the Gradio State).

    Returns:
        ``(image, label_text, new_index)``; when nothing has been processed,
        ``(None, "No images processed.", index)`` with the index unchanged.
    """
    if not processed_images:
        return None, "No images processed.", index
    new_index = (index + 1) % len(processed_images)
    return processed_images[new_index], label_contents[new_index], new_index

def prev_image(index):
    """Step the carousel back by one, wrapping from the first image to the last.

    Args:
        index: Position currently shown (held in the Gradio State).

    Returns:
        ``(image, label_text, new_index)``; when nothing has been processed,
        ``(None, "No images processed.", index)`` with the index unchanged.
    """
    if not processed_images:
        return None, "No images processed.", index
    new_index = (index - 1) % len(processed_images)
    return processed_images[new_index], label_contents[new_index], new_index

# Gradio interface: upload -> per-image carousel -> ZIP download
with gr.Blocks() as interface:
    # Upload widget, then the currently shown image and its label text
    file_input = gr.Files(label="Upload multiple image files", type="filepath")
    image_display = gr.Image(label="Processed Image")
    label_display = gr.Textbox(label="Label File Content")

    # Carousel controls
    prev_button = gr.Button("Previous Image")
    next_button = gr.Button("Next Image")

    # Hidden state holding the carousel position
    current_index = gr.State(0)

    # The button builds the ZIP; the File component offers it for download
    download_button = gr.Button("Prepare and Download All")
    download_file = gr.File()

    # Every carousel-updating event writes the same three components.
    carousel_outputs = [image_display, label_display, current_index]

    file_input.change(process_images, inputs=file_input, outputs=carousel_outputs)
    next_button.click(next_image, inputs=current_index, outputs=carousel_outputs)
    prev_button.click(prev_image, inputs=current_index, outputs=carousel_outputs)

    def prepare_download():
        """Build the ZIP archive and return its path for the download component."""
        return create_zip()

    download_button.click(prepare_download, outputs=download_file)

# Launch the interface
interface.launch(share=True)