# Tomato-Sorting / app.py
# Last commit: 50c74df "Update app.py" (verified), by krishnabalaji
# File size at that commit: 4.23 kB
import json
from pathlib import Path
import cv2
import numpy as np
import streamlit as st
import torch
from ultralytics import YOLO
from tomato_pipeline import load_classifier, make_transform, classify_crop
# -------------------------
# UPLOAD LIMIT
# -------------------------
# `server.*` options cannot be changed from inside the script:
# st.set_option("server.maxUploadSize", ...) raises StreamlitAPIException
# at startup. Set the limit in .streamlit/config.toml instead:
#
#     [server]
#     maxUploadSize = 10  # MB
#
# NOTE(review): the original comment tied this to an HTTP 403 on upload;
# if that persists, the server's XSRF settings are worth checking — confirm
# for this deployment.
# -------------------------
# CONFIG
# -------------------------
st.set_page_config(
    page_title="Tomato AI Inspector",
    page_icon="🍅",
    layout="wide",
)
st.title("🍅 Tomato AI Quality Inspector")
st.caption("YOLO Detection + EfficientNet Classification")

# Model weight files; relative paths are resolved against the working directory.
DETECTOR_PATH = Path("best.pt")
CLASSIFIER_PATH = Path("efficientnet_b0_best.pth")

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# -------------------------
# LOAD MODELS
# -------------------------
@st.cache_resource
def load_models():
    """Construct the detector and the crop classifier.

    Wrapped in ``st.cache_resource`` so the weights are loaded once and the
    same objects are reused on later reruns instead of being rebuilt.

    Returns:
        A ``(detector, classifier)`` pair: the Ultralytics ``YOLO`` model
        built from ``DETECTOR_PATH`` and whatever ``load_classifier`` returns
        for ``CLASSIFIER_PATH`` on ``device``.
    """
    yolo_model = YOLO(str(DETECTOR_PATH))
    crop_classifier = load_classifier(CLASSIFIER_PATH, device)
    return yolo_model, crop_classifier


detector, classifier = load_models()
# -------------------------
# INPUT
# -------------------------
# st.button() returns True only on the one rerun triggered by its own click,
# so a plain `use_sample = st.button(...)` is already False again on the
# rerun caused by clicking "Run Detection" and the sample could never be
# used. The chosen source is therefore kept in session_state: clicking the
# sample button selects the sample, and any change to the upload switches
# back to the uploaded file.
def _select_sample_image():
    """Remember (across reruns) that the bundled sample image was chosen."""
    st.session_state["use_sample"] = True


def _select_uploaded_image():
    """Switch back to the uploaded file whenever the upload changes."""
    st.session_state["use_sample"] = False


uploaded = st.file_uploader(
    "Upload Tomato Image",
    type=["jpg", "png", "jpeg"],
    accept_multiple_files=False,
    on_change=_select_uploaded_image,
)
st.button("Use Sample Image", on_click=_select_sample_image)
# Read by load_image(); a plain bool, as before.
use_sample = bool(st.session_state.get("use_sample", False))
if use_sample:
    st.info("Sample image selected. Click **Run Detection** to analyse it.")
# -------------------------
# IMAGE LOAD FUNCTION
# -------------------------
def load_image(uploaded_file=None):
    """Return the selected input image as an OpenCV array, or ``None``.

    The bundled ``sample.jpg`` is used when the module-level ``use_sample``
    flag is truthy; otherwise ``uploaded_file`` (as returned by
    ``st.file_uploader``) is read. If the selected source is missing, empty
    or cannot be decoded, an error is shown in the UI and ``None`` is
    returned. ``None`` is also returned, without an error, when no source is
    selected at all.

    Args:
        uploaded_file: A Streamlit ``UploadedFile``, or ``None``.

    Returns:
        The image decoded with ``cv2.imdecode(..., cv2.IMREAD_COLOR)``
        (BGR channel order), or ``None``.
    """
    # 👉 SAMPLE IMAGE FALLBACK (reads the module-level `use_sample` flag)
    if use_sample:
        try:
            file_bytes = Path("sample.jpg").read_bytes()
        except OSError:  # missing file, permission problem, directory, ...
            st.error("Sample image not found.")
            return None
    # 👉 USER UPLOAD
    elif uploaded_file is not None:
        # Streamlit's UploadedFile is an io.BytesIO subclass with no `.file`
        # attribute; getvalue() returns the whole buffer regardless of the
        # current read position, so no seek(0) is needed.
        file_bytes = uploaded_file.getvalue()
        if not file_bytes:
            st.error("Upload failed. Try again.")
            return None
    else:
        return None
    # 👉 DECODE IMAGE
    # cv2.imdecode raises (instead of returning None) on an empty buffer, so
    # an empty sample file is routed to the "invalid image" message instead.
    image_np = (
        cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
        if file_bytes
        else None
    )
    if image_np is None:
        st.error("Invalid image file.")
        return None
    return image_np
# -------------------------
# RUN BUTTON
# -------------------------
run = st.button("Run Detection")
# -------------------------
# INFERENCE
# -------------------------
if run:
    image_np = load_image(uploaded)
    if image_np is None:
        st.warning("Please upload or select an image.")
        st.stop()
    h, w = image_np.shape[:2]
    # Draw on a copy so the untouched input can be displayed alongside.
    output = image_np.copy()
    transform = make_transform(224)
    detections = detector.predict(
        source=image_np,
        conf=0.25,
        device=device,
        verbose=False,
    )
    results = []  # one classifier label per kept detection
    good_count = 0
    bad_count = 0
    if detections and detections[0].boxes is not None:
        for box in detections[0].boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
            # Clamp to the image bounds: a negative start index would wrap
            # around when slicing and silently crop the wrong region.
            x1, x2 = max(0, x1), min(w, x2)
            y1, y2 = max(0, y1), min(h, y2)
            crop = image_np[y1:y2, x1:x2]
            if crop.size == 0:
                continue
            # NOTE(review): `crop` is in BGR order (it comes from cv2.imdecode);
            # confirm that classify_crop / the transform expect BGR, not RGB.
            label, conf = classify_crop(
                crop,
                classifier,
                transform,
                device,
                ["bad", "good"],
            )
            if label.lower() == "good":
                good_count += 1
                color = (0, 255, 0)  # green in BGR
            else:
                bad_count += 1
                color = (0, 0, 255)  # red in BGR
            cv2.rectangle(output, (x1, y1), (x2, y2), color, 2)
            cv2.putText(
                output,
                f"{label} {conf:.2f}",
                # Keep the caption inside the frame for boxes at the top edge.
                (x1, max(y1 - 5, 15)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                color,
                2,
            )
            results.append(label)
    # -------------------------
    # DISPLAY
    # -------------------------
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Input Image")
        st.image(image_np, channels="BGR", use_container_width=True)
    with col2:
        st.subheader("Detection Result")
        # Same BGR handling as the input image above.
        st.image(output, channels="BGR", use_container_width=True)
    st.success(f"Total: {len(results)} | Good: {good_count} | Bad: {bad_count}")