Dreamy0 commited on
Commit
efb2242
Β·
verified Β·
1 Parent(s): e819a5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -92
app.py CHANGED
@@ -4,7 +4,7 @@ import zipfile
4
  import warnings
5
  warnings.filterwarnings('ignore')
6
 
7
- import streamlit as st
8
  import torch
9
  import cv2
10
  import numpy as np
@@ -15,23 +15,30 @@ from torchvision.models.detection import fasterrcnn_resnet50_fpn
15
  from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
16
  from pathlib import Path
17
 
18
- @st.cache_resource
 
 
 
19
  def load_model(model_path="seed_frcnn.pth", num_classes=2):
20
  """Load the trained Faster R-CNN model"""
21
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
 
 
 
22
 
23
  try:
24
- model = fasterrcnn_resnet50_fpn(weights=None)
25
- model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
26
- model.roi_heads.box_predictor.cls_score.in_features, num_classes
27
  )
28
- model.load_state_dict(torch.load(model_path, map_location=device))
29
- model.to(device)
30
- model.eval()
31
- return model, device
32
  except Exception as e:
33
- st.error(f"Error loading model: {str(e)}")
34
- return None, None
35
 
36
  def run_inference(model, device, img_pil, score_thresh=0.5):
37
  """Run inference on a PIL image"""
@@ -69,19 +76,24 @@ def draw_boxes(image, detections):
69
 
70
  return cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)
71
 
72
- def process_images(model, device, images, threshold, folder_name):
73
  """Process uploaded images and return crops as a ZIP file"""
 
 
 
 
 
74
  if not folder_name:
75
  folder_name = "seed_crops"
76
 
77
  folder_name = "".join(c for c in folder_name if c.isalnum() or c in ('-', '_'))
78
 
79
- zip_buffer = io.BytesIO()
80
  total_crops = 0
81
  processed_images = 0
82
  preview_images = []
83
 
84
- with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
85
  for idx, img_file in enumerate(images):
86
  try:
87
  img = Image.open(img_file).convert("RGB")
@@ -116,88 +128,94 @@ def process_images(model, device, images, threshold, folder_name):
116
  processed_images += 1
117
 
118
  except Exception as e:
119
- st.error(f"Error processing image {idx}: {str(e)}")
120
  continue
121
 
122
- zip_buffer.seek(0)
123
- return zip_buffer, total_crops, processed_images, preview_images
124
-
125
- # Streamlit UI
126
- st.set_page_config(page_title="Seed Detection & Cropping", page_icon="🌱", layout="wide")
127
-
128
- st.title("🌱 Seed Detection & Cropping Tool")
129
- st.markdown("Upload images to detect seeds using AI-powered Faster R-CNN model. Get all detected seeds as individual cropped images in a ZIP file.")
130
-
131
- # Load model
132
- with st.spinner("Loading model..."):
133
- model, device = load_model("seed_frcnn.pth")
134
-
135
- if model is None:
136
- st.error("Failed to load model. Please check if seed_frcnn.pth exists.")
137
- st.stop()
138
-
139
- st.success(f"βœ… Model loaded successfully on {device}!")
140
-
141
- # Sidebar controls
142
- st.sidebar.header("βš™οΈ Settings")
143
- threshold = st.sidebar.slider("🎚️ Detection Threshold", min_value=0.1, max_value=0.95, value=0.5, step=0.05)
144
- folder_name = st.sidebar.text_input("πŸ“ Output Folder Name", value="seed_crops")
145
 
146
- st.sidebar.markdown("""
147
- ### πŸ’‘ Tips:
148
- - **Lower threshold (0.3-0.5)**: More detections, may include false positives
149
- - **Higher threshold (0.6-0.8)**: Conservative, high-confidence only
150
- """)
 
151
 
152
- # File uploader
153
- uploaded_files = st.file_uploader("πŸ“€ Upload Images", type=["jpg", "jpeg", "png", "bmp"], accept_multiple_files=True)
154
-
155
- if uploaded_files:
156
- st.info(f"πŸ“· {len(uploaded_files)} image(s) uploaded")
157
 
158
- if st.button("πŸ” Detect & Crop Seeds", type="primary"):
159
- with st.spinner("Processing images..."):
160
- zip_buffer, total_crops, processed_images, preview_images = process_images(
161
- model, device, uploaded_files, threshold, folder_name
 
 
 
 
 
 
 
 
 
 
 
162
  )
163
 
164
- if total_crops > 0:
165
- st.success(f"βœ… Processed {processed_images} images")
166
- st.success(f"🌱 Detected and saved {total_crops} seed crops")
167
-
168
- # Download button
169
- st.download_button(
170
- label="πŸ’Ύ Download ZIP File",
171
- data=zip_buffer.getvalue(),
172
- file_name=f"{folder_name}.zip",
173
- mime="application/zip"
174
- )
175
-
176
- # Preview
177
- if preview_images:
178
- st.markdown("### πŸ–ΌοΈ Preview (First 3 images with detections)")
179
- cols = st.columns(min(3, len(preview_images)))
180
- for idx, preview_img in enumerate(preview_images):
181
- cols[idx].image(preview_img, use_column_width=True)
182
- else:
183
- st.warning("⚠️ No seeds detected! Try lowering the threshold.")
184
- else:
185
- st.info("πŸ‘† Upload images to get started")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
- # Instructions
188
- with st.expander("πŸ“‹ How to Use"):
189
- st.markdown("""
190
- 1. **Upload Images**: Click the upload box and select one or multiple images
191
- 2. **Adjust Threshold**: Use the slider in the sidebar to control detection sensitivity
192
- 3. **Set Folder Name**: Enter the folder name for your crops
193
- 4. **Click Detect**: Process your images
194
- 5. **Download**: Get your ZIP file with all cropped seeds!
195
-
196
- ### Output Format:
197
- ```
198
- your_folder_name/
199
- β”œβ”€β”€ image1_seed_000_score_0.850.jpg
200
- β”œβ”€β”€ image1_seed_001_score_0.720.jpg
201
- └── image2_seed_000_score_0.910.jpg
202
- ```
203
- """)
 
4
  import warnings
5
  warnings.filterwarnings('ignore')
6
 
7
+ import gradio as gr
8
  import torch
9
  import cv2
10
  import numpy as np
 
15
  from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
16
  from pathlib import Path
17
 
18
+ # Global model variables
19
+ MODEL = None
20
+ DEVICE = None
21
+
22
  def load_model(model_path="seed_frcnn.pth", num_classes=2):
23
  """Load the trained Faster R-CNN model"""
24
+ global MODEL, DEVICE
25
+
26
+ if MODEL is not None:
27
+ return MODEL, DEVICE
28
+
29
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
30
 
31
  try:
32
+ MODEL = fasterrcnn_resnet50_fpn(weights=None)
33
+ MODEL.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
34
+ MODEL.roi_heads.box_predictor.cls_score.in_features, num_classes
35
  )
36
+ MODEL.load_state_dict(torch.load(model_path, map_location=DEVICE))
37
+ MODEL.to(DEVICE)
38
+ MODEL.eval()
39
+ return MODEL, DEVICE
40
  except Exception as e:
41
+ raise Exception(f"Error loading model: {str(e)}")
 
42
 
43
  def run_inference(model, device, img_pil, score_thresh=0.5):
44
  """Run inference on a PIL image"""
 
76
 
77
  return cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)
78
 
79
+ def process_images(images, threshold, folder_name):
80
  """Process uploaded images and return crops as a ZIP file"""
81
+ if not images:
82
+ return None, "❌ No images uploaded!", None, None, None
83
+
84
+ model, device = load_model()
85
+
86
  if not folder_name:
87
  folder_name = "seed_crops"
88
 
89
  folder_name = "".join(c for c in folder_name if c.isalnum() or c in ('-', '_'))
90
 
91
+ zip_path = f"{folder_name}.zip"
92
  total_crops = 0
93
  processed_images = 0
94
  preview_images = []
95
 
96
+ with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zip_file:
97
  for idx, img_file in enumerate(images):
98
  try:
99
  img = Image.open(img_file).convert("RGB")
 
128
  processed_images += 1
129
 
130
  except Exception as e:
131
+ print(f"Error processing image {idx}: {str(e)}")
132
  continue
133
 
134
+ if total_crops > 0:
135
+ status_msg = f"βœ… Processed {processed_images} images\n🌱 Detected and saved {total_crops} seed crops"
136
+ preview1 = preview_images[0] if len(preview_images) > 0 else None
137
+ preview2 = preview_images[1] if len(preview_images) > 1 else None
138
+ preview3 = preview_images[2] if len(preview_images) > 2 else None
139
+ return zip_path, status_msg, preview1, preview2, preview3
140
+ else:
141
+ return None, "⚠️ No seeds detected! Try lowering the threshold.", None, None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
+ # Initialize model on startup
144
+ try:
145
+ load_model()
146
+ model_status = "βœ… Model loaded successfully!"
147
+ except Exception as e:
148
+ model_status = f"❌ Error loading model: {str(e)}"
149
 
150
+ # Create Gradio interface
151
+ with gr.Blocks(title="Seed Detection & Cropping", theme=gr.themes.Soft()) as demo:
152
+ gr.Markdown("# 🌱 Seed Detection & Cropping Tool")
153
+ gr.Markdown("Upload images to detect seeds using AI-powered Faster R-CNN model. Get all detected seeds as individual cropped images in a ZIP file.")
154
+ gr.Markdown(f"**Model Status:** {model_status}")
155
 
156
+ with gr.Row():
157
+ with gr.Column(scale=1):
158
+ gr.Markdown("## βš™οΈ Settings")
159
+ threshold = gr.Slider(
160
+ minimum=0.1,
161
+ maximum=0.95,
162
+ value=0.5,
163
+ step=0.05,
164
+ label="🎚️ Detection Threshold",
165
+ info="Lower = more detections, Higher = only high-confidence"
166
+ )
167
+ folder_name = gr.Textbox(
168
+ value="seed_crops",
169
+ label="πŸ“ Output Folder Name",
170
+ placeholder="seed_crops"
171
  )
172
 
173
+ gr.Markdown("""
174
+ ### πŸ’‘ Tips:
175
+ - **Lower threshold (0.3-0.5)**: More detections, may include false positives
176
+ - **Higher threshold (0.6-0.8)**: Conservative, high-confidence only
177
+ """)
178
+
179
+ with gr.Column(scale=2):
180
+ images = gr.File(
181
+ file_count="multiple",
182
+ label="πŸ“€ Upload Images",
183
+ file_types=["image"]
184
+ )
185
+
186
+ process_btn = gr.Button("πŸ” Detect & Crop Seeds", variant="primary", size="lg")
187
+
188
+ status = gr.Textbox(label="Status", interactive=False)
189
+ download = gr.File(label="πŸ’Ύ Download ZIP File")
190
+
191
+ gr.Markdown("### πŸ–ΌοΈ Preview (First 3 images with detections)")
192
+ with gr.Row():
193
+ preview1 = gr.Image(label="Preview 1")
194
+ preview2 = gr.Image(label="Preview 2")
195
+ preview3 = gr.Image(label="Preview 3")
196
+
197
+ with gr.Accordion("πŸ“‹ How to Use", open=False):
198
+ gr.Markdown("""
199
+ 1. **Upload Images**: Click the upload box and select one or multiple images
200
+ 2. **Adjust Threshold**: Use the slider to control detection sensitivity
201
+ 3. **Set Folder Name**: Enter the folder name for your crops
202
+ 4. **Click Detect**: Process your images
203
+ 5. **Download**: Get your ZIP file with all cropped seeds!
204
+
205
+ ### Output Format:
206
+ ```
207
+ your_folder_name/
208
+ β”œβ”€β”€ image1_seed_000_score_0.850.jpg
209
+ β”œβ”€β”€ image1_seed_001_score_0.720.jpg
210
+ └── image2_seed_000_score_0.910.jpg
211
+ ```
212
+ """)
213
+
214
+ process_btn.click(
215
+ fn=process_images,
216
+ inputs=[images, threshold, folder_name],
217
+ outputs=[download, status, preview1, preview2, preview3]
218
+ )
219
 
220
+ if __name__ == "__main__":
221
+ demo.launch()