sparsh007 commited on
Commit
bd507a2
·
verified ·
1 Parent(s): 30eaedd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -28
app.py CHANGED
@@ -3,15 +3,10 @@ import gradio as gr
3
  from PIL import Image
4
  import os
5
  import zipfile
6
- import io
7
- import shutil
8
 
9
- # Clear the torch hub cache to prevent any previous corrupted downloads
10
- shutil.rmtree(torch.hub._get_torch_home(), ignore_errors=True)
11
-
12
- # Load the YOLOv5 model from the uploaded best.pt file
13
  model = torch.hub.load(
14
- 'ultralytics/yolov5', 'custom', path='best-3.pt', trust_repo=True
15
  )
16
 
17
  # Global variables to store images and labels for carousel functionality
@@ -21,10 +16,7 @@ processed_image_paths = []
21
 
22
  # Temporary folder for saving processed files
23
  TEMP_DIR = "temp_processed"
24
-
25
- # Ensure the temp directory exists
26
- if not os.path.exists(TEMP_DIR):
27
- os.makedirs(TEMP_DIR)
28
 
29
  # Function to process all uploaded images
30
  def process_images(files):
@@ -79,10 +71,16 @@ def process_images(files):
79
  else:
80
  return None, "No images found.", 0
81
 
82
- # Function to create and return a ZIP file for download
83
  def create_zip():
84
- zip_buffer = io.BytesIO()
85
- with zipfile.ZipFile(zip_buffer, 'w') as z:
 
 
 
 
 
 
86
  # Add images and labels to the ZIP file
87
  for image_path in processed_image_paths:
88
  z.write(image_path, os.path.basename(image_path))
@@ -95,9 +93,7 @@ def create_zip():
95
  label_path = os.path.join(TEMP_DIR, label_filename)
96
  z.write(label_path, label_filename)
97
 
98
- zip_buffer.seek(0) # Go to the start of the buffer
99
- # Return the bytes of the ZIP file
100
- return zip_buffer.getvalue()
101
 
102
  # Function to navigate through images
103
  def next_image(index):
@@ -130,11 +126,9 @@ with gr.Blocks() as interface:
130
  # Hidden state to store current index
131
  current_index = gr.State(0)
132
 
133
- # Button to prepare the ZIP file for download
134
- prepare_download_button = gr.Button("Prepare Download")
135
-
136
- # Download button
137
- download_button = gr.DownloadButton(label="Download ZIP", visible=False)
138
 
139
  # Define functionality when files are uploaded
140
  file_input.change(
@@ -155,15 +149,15 @@ with gr.Blocks() as interface:
155
  outputs=[image_display, label_display, current_index]
156
  )
157
 
158
- # Define functionality for the prepare download button
159
  def prepare_download():
160
- data = create_zip()
161
- return gr.update(value={"name": "processed_images_annotations.zip", "data": data}, visible=True)
162
 
163
- prepare_download_button.click(
164
  prepare_download,
165
- outputs=download_button
166
  )
167
 
168
  # Launch the interface
169
- interface.launch()
 
3
  from PIL import Image
4
  import os
5
  import zipfile
 
 
6
 
7
+ # Load the YOLOv5 model (ensure the path is correct)
 
 
 
8
  model = torch.hub.load(
9
+ 'ultralytics/yolov5', 'custom', path='best.pt'
10
  )
11
 
12
  # Global variables to store images and labels for carousel functionality
 
16
 
17
  # Temporary folder for saving processed files
18
  TEMP_DIR = "temp_processed"
19
+ os.makedirs(TEMP_DIR, exist_ok=True)
 
 
 
20
 
21
  # Function to process all uploaded images
22
  def process_images(files):
 
71
  else:
72
  return None, "No images found.", 0
73
 
74
+ # Function to create and return the path to the ZIP file for download
75
  def create_zip():
76
+ zip_filename = "processed_images_annotations.zip"
77
+ zip_path = os.path.join(TEMP_DIR, zip_filename)
78
+
79
+ # Remove existing ZIP file if it exists
80
+ if os.path.exists(zip_path):
81
+ os.remove(zip_path)
82
+
83
+ with zipfile.ZipFile(zip_path, 'w') as z:
84
  # Add images and labels to the ZIP file
85
  for image_path in processed_image_paths:
86
  z.write(image_path, os.path.basename(image_path))
 
93
  label_path = os.path.join(TEMP_DIR, label_filename)
94
  z.write(label_path, label_filename)
95
 
96
+ return zip_path # Return the file path as a string
 
 
97
 
98
  # Function to navigate through images
99
  def next_image(index):
 
126
  # Hidden state to store current index
127
  current_index = gr.State(0)
128
 
129
+ # Button to download all processed images and annotations as a ZIP file
130
+ download_button = gr.Button("Prepare and Download All")
131
+ download_file = gr.File()
 
 
132
 
133
  # Define functionality when files are uploaded
134
  file_input.change(
 
149
  outputs=[image_display, label_display, current_index]
150
  )
151
 
152
+ # Define functionality for the download button to zip the files and allow download
153
  def prepare_download():
154
+ zip_path = create_zip()
155
+ return zip_path
156
 
157
+ download_button.click(
158
  prepare_download,
159
+ outputs=download_file
160
  )
161
 
162
  # Launch the interface
163
+ interface.launch(share=True)