# Spaces: Sleeping  (status banner captured along with this file; not part of the app)
import json
from pathlib import Path

import cv2
import numpy as np
import streamlit as st
import torch
from ultralytics import YOLO

from tomato_pipeline import classify_crop, load_classifier, make_transform
# -------------------------
# UPLOAD LIMIT
# -------------------------
# `server.maxUploadSize` is a *server* option. It cannot be changed from inside
# the script: `st.set_option("server.maxUploadSize", 10)` raises
# StreamlitAPIException before anything renders. Set it where Streamlit reads
# server options instead, e.g. in `.streamlit/config.toml`:
#
#     [server]
#     maxUploadSize = 10  # MB
#
# or on the command line: `streamlit run <script>.py --server.maxUploadSize=10`.
# NOTE(review): the original comment tied this to a "403 error". Upload 403s on
# hosted deployments are often XSRF/CORS related rather than size related.
# Confirm the actual cause.

# -------------------------
# CONFIG
# -------------------------
# set_page_config is the first Streamlit call in the script, which older
# Streamlit releases require.
st.set_page_config(
    page_title="Tomato AI Inspector",
    page_icon="🍅",
    layout="wide",
)

st.title("🍅 Tomato AI Quality Inspector")
st.caption("YOLO Detection + EfficientNet Classification")

# Model weight files, resolved relative to the current working directory.
DETECTOR_PATH = Path("best.pt")
CLASSIFIER_PATH = Path("efficientnet_b0_best.pth")

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# -------------------------
# LOAD MODELS
# -------------------------
@st.cache_resource
def load_models():
    """Load the YOLO detector and the tomato crop classifier.

    Streamlit re-executes this whole script on every widget interaction
    (upload, button click, ...). ``st.cache_resource`` keeps the returned
    models alive across those reruns, so the weights are loaded once per
    process instead of on every click.

    NOTE(review): cached resources are shared by all sessions. If several
    users run inference concurrently on the same YOLO instance, confirm that
    is safe for the Ultralytics version in use, or guard predict() with a lock.

    Returns:
        tuple: ``(detector, classifier)``. ``detector`` is the ``YOLO`` model
        built from ``DETECTOR_PATH``. ``classifier`` is whatever
        ``load_classifier`` returns for ``CLASSIFIER_PATH`` on ``device``.
    """
    detector = YOLO(str(DETECTOR_PATH))
    classifier = load_classifier(CLASSIFIER_PATH, device)
    return detector, classifier


detector, classifier = load_models()
# -------------------------
# INPUT
# -------------------------
uploaded = st.file_uploader(
    "Upload Tomato Image",
    type=["jpg", "png", "jpeg"],
    accept_multiple_files=False,
)

# A checkbox rather than a button: st.button is True only during the single
# rerun its click triggers. By the time "Run Detection" is clicked (a later
# rerun), a "Use Sample Image" button would read False again, so the sample
# could never be used. A checkbox keeps its value across reruns.
use_sample = st.checkbox("Use Sample Image")
# -------------------------
# IMAGE LOAD FUNCTION
# -------------------------
def load_image(uploaded_file=None):
    """Return the image to inspect as a NumPy array, or ``None``.

    When the module-level ``use_sample`` flag is set, the bundled
    ``sample.jpg`` takes precedence. Otherwise ``uploaded_file`` is used. On
    every failure an ``st.error`` message is shown and ``None`` is returned.
    ``None`` is also returned silently when there is no input at all.

    Args:
        uploaded_file: The object returned by ``st.file_uploader``, or
            ``None`` when nothing has been uploaded.

    Returns:
        The image decoded with ``cv2.IMREAD_COLOR`` (OpenCV BGR channel
        order), or ``None``.
    """
    if use_sample:
        # Sample image fallback. OSError covers missing, unreadable or
        # directory paths without swallowing KeyboardInterrupt and friends
        # the way a bare `except:` does.
        try:
            file_bytes = Path("sample.jpg").read_bytes()
        except OSError:
            st.error("Sample image not found.")
            return None
    elif uploaded_file is not None:
        # User upload. Streamlit's UploadedFile is itself a BytesIO and has no
        # `.file` attribute. getvalue() returns the whole buffer regardless of
        # the current read position.
        file_bytes = uploaded_file.getvalue()
        if not file_bytes:
            st.error("Upload failed. Try again.")
            return None
    else:
        return None

    # cv2.imdecode raises (rather than returning None) on an empty buffer,
    # e.g. a zero-byte sample.jpg, so reject that case first.
    if not file_bytes:
        st.error("Invalid image file.")
        return None

    image_np = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
    if image_np is None:
        # imdecode returns None for data it cannot parse as an image.
        st.error("Invalid image file.")
        return None
    return image_np
# -------------------------
# RUN BUTTON
# -------------------------
run = st.button("Run Detection")

# -------------------------
# INFERENCE
# -------------------------
if run:
    image_np = load_image(uploaded)
    if image_np is None:
        st.warning("Please upload or select an image.")
        st.stop()

    h, w = image_np.shape[:2]
    # Boxes and labels are drawn on a copy. image_np stays untouched for the
    # classifier crops and the "Input Image" panel.
    output = image_np.copy()
    transform = make_transform(224)

    detections = detector.predict(
        source=image_np,
        conf=0.25,
        device=device,
        verbose=False,
    )

    results = []
    good_count = 0
    bad_count = 0

    if detections and detections[0].boxes is not None:
        for box in detections[0].boxes:
            x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].tolist())
            # Clamp to the image bounds. A negative coordinate would otherwise
            # act as a wrap-around NumPy slice index and yield a wrong crop.
            x1, x2 = max(0, min(x1, w)), max(0, min(x2, w))
            y1, y2 = max(0, min(y1, h)), max(0, min(y2, h))

            crop = image_np[y1:y2, x1:x2]
            if crop.size == 0:
                continue

            # NOTE(review): crop is in OpenCV BGR order. Confirm classify_crop
            # expects BGR rather than RGB input.
            label, conf = classify_crop(
                crop,
                classifier,
                transform,
                device,
                ["bad", "good"],
            )

            if label.lower() == "good":
                good_count += 1
                color = (0, 255, 0)  # green in BGR
            else:
                bad_count += 1
                color = (0, 0, 255)  # red in BGR

            cv2.rectangle(output, (x1, y1), (x2, y2), color, 2)
            # Caption sits just above the box. For boxes touching the top edge
            # it is pushed down so the text stays on the canvas.
            text_y = max(y1 - 5, 15)
            cv2.putText(
                output,
                f"{label} {conf:.2f}",
                (x1, text_y),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                color,
                2,
            )
            results.append(label)

    # -------------------------
    # DISPLAY
    # -------------------------
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Input Image")
        st.image(image_np, channels="BGR", use_container_width=True)
    with col2:
        st.subheader("Detection Result")
        # Same BGR handling as the input panel; Streamlit converts for display.
        st.image(output, channels="BGR", use_container_width=True)

    st.success(f"Total: {len(results)} | Good: {good_count} | Bad: {bad_count}")