"""Streamlit front end for the tomato quality inspector.

Pipeline: a YOLO detector (``best.pt``) proposes tomato boxes, then each crop
is classified by the EfficientNet model loaded through
``tomato_pipeline.load_classifier``. Boxes labelled "good" are drawn green,
everything else red, and summary counts are shown next to the input image.
"""

import json
from pathlib import Path

import cv2
import numpy as np
import streamlit as st
import torch
from ultralytics import YOLO

from tomato_pipeline import load_classifier, make_transform, classify_crop

# -------------------------
# UPLOAD LIMIT
# -------------------------
# ``server.*`` options such as ``server.maxUploadSize`` are not scriptable:
# ``st.set_option`` raises StreamlitAPIException for them, which crashed the
# app before anything rendered. Configure them in ``.streamlit/config.toml``:
#
#     [server]
#     maxUploadSize = 10  # MB
#
# or on the command line: ``streamlit run <app>.py --server.maxUploadSize 10``.
# NOTE(review): an HTTP 403 on upload is often an XSRF/CORS setting
# (``server.enableXsrfProtection``) rather than the size limit - confirm for
# the deployment target.

# -------------------------
# CONFIG
# -------------------------
st.set_page_config(
    page_title="Tomato AI Inspector",
    page_icon="🍅",
    layout="wide",
)

st.title("🍅 Tomato AI Quality Inspector")
st.caption("YOLO Detection + EfficientNet Classification")

DETECTOR_PATH = Path("best.pt")
CLASSIFIER_PATH = Path("efficientnet_b0_best.pth")
SAMPLE_IMAGE_PATH = Path("sample.jpg")

# Class-name order handed to classify_crop; presumably must match the
# classifier's training label order - TODO confirm against tomato_pipeline.
CLASS_NAMES = ["bad", "good"]
DETECTION_CONF = 0.25  # minimum YOLO confidence for a box to be kept
CLASSIFIER_INPUT_SIZE = 224  # size passed to make_transform

device = "cuda" if torch.cuda.is_available() else "cpu"

# Session-state key remembering that the user chose the bundled sample image.
_SAMPLE_KEY = "use_sample_image"


# -------------------------
# LOAD MODELS
# -------------------------
@st.cache_resource
def load_models():
    """Load the YOLO detector and the crop classifier.

    ``st.cache_resource`` keeps the returned objects across reruns, so the
    weights are not reloaded on every widget interaction.

    Returns:
        ``(detector, classifier)`` tuple.
    """
    detector = YOLO(str(DETECTOR_PATH))
    classifier = load_classifier(CLASSIFIER_PATH, device)
    return detector, classifier


detector, classifier = load_models()


# -------------------------
# INPUT
# -------------------------
def _clear_sample_choice():
    """Forget an earlier "Use Sample Image" click once the upload changes."""
    st.session_state[_SAMPLE_KEY] = False


uploaded = st.file_uploader(
    "Upload Tomato Image",
    type=["jpg", "png", "jpeg"],
    accept_multiple_files=False,
    on_change=_clear_sample_choice,
)

use_sample = st.button("Use Sample Image")
# st.button is True only on the rerun triggered by its own click; by the time
# "Run Detection" is clicked it is False again. Persist the choice instead.
if use_sample:
    st.session_state[_SAMPLE_KEY] = True
if st.session_state.get(_SAMPLE_KEY, False):
    st.info("Sample image selected - click 'Run Detection'.")


# -------------------------
# IMAGE LOAD FUNCTION
# -------------------------
def load_image(uploaded_file=None, use_sample_image=None):
    """Return the selected image as a BGR ``np.ndarray``, or ``None``.

    Args:
        uploaded_file: Result of ``st.file_uploader`` (an ``io.BytesIO``
            subclass) or ``None``.
        use_sample_image: Load ``sample.jpg`` instead of the upload. When
            ``None``, the choice persisted by the "Use Sample Image" button in
            session state is used. The sample takes priority over the upload.

    Returns:
        The image decoded with ``cv2.IMREAD_COLOR``, or ``None`` when nothing
        is selected or the bytes cannot be read/decoded. For read/decode
        failures an ``st.error`` message is shown.
    """
    if use_sample_image is None:
        use_sample_image = st.session_state.get(_SAMPLE_KEY, False)

    if use_sample_image:
        try:
            file_bytes = SAMPLE_IMAGE_PATH.read_bytes()
        except OSError:
            st.error("Sample image not found.")
            return None
    elif uploaded_file is not None:
        # UploadedFile has no ``.file`` attribute; getvalue() returns the whole
        # buffer regardless of the current read position.
        file_bytes = uploaded_file.getvalue()
        if not file_bytes:
            st.error("Upload failed. Try again.")
            return None
    else:
        return None

    # cv2.imdecode asserts on an empty buffer, so reject it up front.
    if not file_bytes:
        st.error("Invalid image file.")
        return None

    image_np = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
    if image_np is None:
        st.error("Invalid image file.")
        return None
    return image_np


# -------------------------
# RUN BUTTON
# -------------------------
run = st.button("Run Detection")

# -------------------------
# INFERENCE
# -------------------------
if run:
    image_np = load_image(uploaded)
    if image_np is None:
        st.warning("Please upload or select an image.")
        st.stop()

    h, w = image_np.shape[:2]
    output = image_np.copy()
    transform = make_transform(CLASSIFIER_INPUT_SIZE)

    detections = detector.predict(
        source=image_np,
        conf=DETECTION_CONF,
        device=device,
        verbose=False,
    )

    results = []
    good_count = 0
    bad_count = 0

    if detections and detections[0].boxes is not None:
        for box in detections[0].boxes:
            x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].tolist())
            # Clamp to the image: a negative index would wrap around in numpy
            # slicing and silently produce the wrong crop.
            x1, x2 = max(0, min(x1, w)), max(0, min(x2, w))
            y1, y2 = max(0, min(y1, h)), max(0, min(y2, h))

            crop = image_np[y1:y2, x1:x2]
            if crop.size == 0:
                continue

            label, conf = classify_crop(
                crop, classifier, transform, device, CLASS_NAMES
            )

            if label.lower() == "good":
                good_count += 1
                color = (0, 255, 0)  # green (BGR)
            else:
                bad_count += 1
                color = (0, 0, 255)  # red (BGR)

            cv2.rectangle(output, (x1, y1), (x2, y2), color, 2)
            # Keep the label baseline inside the image for boxes at the top edge.
            cv2.putText(
                output,
                f"{label} {conf:.2f}",
                (x1, max(y1 - 5, 15)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                color,
                2,
            )
            results.append(label)

    # -------------------------
    # DISPLAY
    # -------------------------
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Input Image")
        st.image(image_np, channels="BGR", use_container_width=True)

    with col2:
        st.subheader("Detection Result")
        st.image(cv2.cvtColor(output, cv2.COLOR_BGR2RGB), use_container_width=True)

    st.success(f"Total: {len(results)} | Good: {good_count} | Bad: {bad_count}")