gphua1's picture
Fix persistent screen shaking - disable Streamlit animations and stabilize layout
ea77165
"""
Production Defect Detection Application
Supports both API and Web Interface modes
"""
import os
import io
import sys
import base64
import json
import time
from pathlib import Path
from typing import Dict, Optional, Tuple
import argparse
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Model imports
from models.vision_transformer import get_model
# Global model cache shared by load_model(), predict_image() and the API
# endpoints. Every key that load_model() writes is declared up front so the
# cache's shape is visible in one place:
#   model     - the loaded torch.nn.Module (None until load_model() succeeds)
#   device    - torch.device the model was moved to
#   transform - preprocessing pipeline built by get_transform()
#   info      - metadata dict returned as the third element of load_model()
#   is_bf16   - True when the model's weights were cast to torch.bfloat16
_model_cache = {
    "model": None,
    "device": None,
    "transform": None,
    "info": {},
    "is_bf16": False,
}
def get_transform():
    """Build the preprocessing pipeline applied to every image before inference.

    The pipeline resizes to 224x224, normalizes each channel with the commonly
    used ImageNet mean/std values, and converts the result to a torch tensor
    via ``ToTensorV2``.

    Returns:
        An ``A.Compose`` pipeline, invoked as ``transform(image=array)``.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline_steps = [
        A.Resize(224, 224),
        A.Normalize(mean=channel_mean, std=channel_std),
        ToTensorV2(),
    ]
    return A.Compose(pipeline_steps)
def load_model(model_path: Optional[str] = None) -> Tuple[torch.nn.Module, torch.device, dict]:
    """Load the defect classifier, caching it in ``_model_cache`` for reuse.

    Args:
        model_path: Checkpoint to load. When ``None``, the first existing file
            among ``models/best_model_bf16.pth``, ``models/best_model.pth`` and
            ``models/toy_model.pth`` is used (or the already-cached model is
            returned). An explicit path that differs from the cached
            checkpoint forces a reload instead of silently returning the
            wrong model.

    Returns:
        ``(model, device, model_info)``: the model in eval mode on ``device``
        (CUDA when available, otherwise CPU), and a dict with ``model_type``,
        ``accuracy``, ``model_path`` and ``is_bf16``.

    Raises:
        FileNotFoundError: if no checkpoint can be found.

    Note:
        A checkpoint is treated as BF16 purely because its filename contains
        "bf16"; the model's weights are then cast to ``torch.bfloat16``.
    """
    cached_path = _model_cache.get("model_path")
    if _model_cache["model"] is not None and (
        model_path is None
        or (cached_path is not None and Path(model_path).resolve() == Path(cached_path))
    ):
        return _model_cache["model"], _model_cache["device"], _model_cache.get("info", {})

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Find model file - prioritize BF16 model
    if model_path is None:
        candidates = [
            Path("models/best_model_bf16.pth"),  # Prioritize BF16 model
            Path("models/best_model.pth"),
            Path("models/toy_model.pth"),
        ]
        model_path = next((str(path) for path in candidates if path.exists()), None)
        if model_path is None:
            raise FileNotFoundError("No model found. Train a model first: python train.py")
    elif not Path(model_path).exists():
        raise FileNotFoundError(f"Model file not found: {model_path}")

    is_bf16 = "bf16" in str(model_path).lower()

    # NOTE: torch.load unpickles the checkpoint; only load trusted files.
    checkpoint = torch.load(model_path, map_location=device)
    model_type = checkpoint.get('model_type', 'efficient_vit')

    model = get_model(model_type, num_classes=2, pretrained=False)
    model.load_state_dict(checkpoint['model_state_dict'])

    if is_bf16:
        # Convert model to BF16 for inference
        model = model.to(dtype=torch.bfloat16)
        print(f"✅ Loaded BF16 model from {model_path}")

    model.to(device)
    model.eval()

    model_info = {
        'model_type': model_type,
        'accuracy': checkpoint.get('best_acc', checkpoint.get('accuracy', 0)),
        'model_path': model_path,
        'is_bf16': is_bf16,
    }

    # Cache model; the resolved path lets later explicit requests detect a hit.
    _model_cache["model"] = model
    _model_cache["device"] = device
    _model_cache["transform"] = get_transform()
    _model_cache["info"] = model_info
    _model_cache["is_bf16"] = is_bf16
    _model_cache["model_path"] = str(Path(model_path).resolve())

    return model, device, model_info
@torch.no_grad()
def predict_image(image: np.ndarray, model=None) -> Dict:
    """Classify one image as NORMAL or DEFECTIVE.

    Args:
        image: Image array fed to the preprocessing pipeline as
            ``transform(image=image)``; callers in this file pass RGB arrays.
        model: Model to run; defaults to the cached model from ``load_model()``.

    Returns:
        dict with ``prediction`` ('DEFECTIVE' when class index 1 has the
        highest probability, else 'NORMAL'), ``confidence`` (that highest
        probability), ``defect_probability`` / ``normal_probability`` (class 1
        / class 0 softmax probabilities) and ``inference_time`` in milliseconds
        (forward pass + softmax; preprocessing excluded).
    """
    if model is None:
        model, _, _ = load_model()

    # Take device and input dtype from the model actually being run, so a
    # caller-supplied model is fed correctly even if it differs from the
    # cached one (e.g. an FP32 model while the cache holds a BF16 model).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
        input_dtype = first_param.dtype if first_param.is_floating_point() else torch.float32
    else:
        device = torch.device('cpu')
        input_dtype = torch.float32

    transform = _model_cache.get("transform") or get_transform()

    # Preprocess
    augmented = transform(image=image)
    image_tensor = augmented['image'].unsqueeze(0).to(device=device, dtype=input_dtype)

    is_cuda = device.type == 'cuda'
    if is_cuda:
        # Don't attribute pending host->device work to the forward pass.
        torch.cuda.synchronize(device)
    start_time = time.perf_counter()

    outputs = model(image_tensor)
    # Reduced-precision logits are upcast for a numerically stable softmax.
    if outputs.dtype in (torch.bfloat16, torch.float16):
        outputs = outputs.float()
    probs = F.softmax(outputs, dim=1)
    confidence, predicted = torch.max(probs, 1)

    if is_cuda:
        # CUDA kernels run asynchronously; wait for them before stopping the clock.
        torch.cuda.synchronize(device)
    inference_time = (time.perf_counter() - start_time) * 1000

    return {
        'prediction': 'DEFECTIVE' if predicted.item() == 1 else 'NORMAL',
        'confidence': confidence.item(),
        'defect_probability': probs[0][1].item(),
        'normal_probability': probs[0][0].item(),
        'inference_time': inference_time,
    }
def run_streamlit():
    """Run the Streamlit web interface with the Rocket Lab theme.

    Outside a Streamlit script run (e.g. ``python app.py --mode web``), this
    relaunches the current file with ``streamlit run`` in a subprocess and
    returns when that process exits. Inside Streamlit it renders the page:
    themed CSS, sample images from ``examples/normal`` and
    ``examples/defective`` in the sidebar, an image uploader, and the model's
    prediction for the chosen image.

    Image selection follows the most recent action. Clicking a sample's
    "Load" button shows that sample even while a file sits in the uploader.
    Uploading, changing or removing a file switches back to the uploader.
    """
    import subprocess

    # If not running through streamlit, restart with streamlit
    if "streamlit.runtime.scriptrunner" not in sys.modules:
        print("🚀 Starting Rocket Lab Defect Detection System...")
        print(" Opening browser to http://localhost:8501")
        # Prefer the venv's streamlit executable when it exists next to this file
        venv_streamlit = os.path.join(os.path.dirname(__file__), "venv", "bin", "streamlit")
        streamlit_cmd = venv_streamlit if os.path.exists(venv_streamlit) else "streamlit"
        subprocess.run([streamlit_cmd, "run", __file__])
        return

    import streamlit as st

    # Rocket Lab themed configuration - with permanent sidebar
    st.set_page_config(
        page_title="RKLB Defect Detection",
        page_icon="🚀",
        layout="wide",
        initial_sidebar_state="expanded",  # Make sure sidebar is expanded
    )

    # Custom CSS with professional theme and permanent sidebar
    st.markdown("""
<style>
/* Force sidebar to always be visible and expanded */
section[data-testid="stSidebar"] {
background: #0f0f0f !important;
border-right: 2px solid #333;
width: 21rem !important;
min-width: 21rem !important;
max-width: 21rem !important;
display: block !important;
position: relative !important;
left: 0 !important;
visibility: visible !important;
opacity: 1 !important;
transform: none !important;
}
/* Hide sidebar collapse button completely */
[data-testid="collapsedControl"] {
display: none !important;
}
button[kind="header"] {
display: none !important;
}
/* Hide the hamburger menu button */
[data-testid="baseButton-header"] {
display: none !important;
}
/* Ensure sidebar content is always visible */
section[data-testid="stSidebar"] > div {
display: block !important;
visibility: visible !important;
opacity: 1 !important;
}
section[data-testid="stSidebar"] > div:first-child {
padding-top: 2rem;
}
/* Clean dark background */
.stApp {
background: #0a0a0a;
}
/* Hide the link icon buttons next to headers */
[data-testid="StyledLinkIconContainer"] {
display: none !important;
}
/* Hide anchor links in headers */
.stMarkdown h1 a, .stMarkdown h2 a, .stMarkdown h3 a {
display: none !important;
}
/* Hide buttons that appear on hover for headers */
.element-container:has(.stMarkdown h1, .stMarkdown h2, .stMarkdown h3) button[kind="headerLink"] {
display: none !important;
}
/* Hide all header link anchors */
[data-testid="stHeaderActionElements"] {
display: none !important;
}
/* Hide copy buttons and link buttons */
.stMarkdown [data-testid="stCopyButton"],
.stMarkdown button[title*="link"] {
display: none !important;
}
/* Main header - Professional Rocket Lab style */
.main-header {
padding: 15px 0 20px 0;
border-bottom: 3px solid #dc2626;
margin-bottom: 30px;
background: linear-gradient(135deg, #1a1a1a 0%, #0f0f0f 100%);
margin-top: -3.5rem; /* Position header higher */
box-shadow: 0 4px 12px rgba(220, 38, 38, 0.1);
}
.main-header h1 {
color: #ffffff;
font-size: 1.9rem;
letter-spacing: 2px;
margin: 0;
font-weight: 700; /* Bold font weight */
text-align: left;
padding-left: 30px;
text-transform: uppercase;
}
.main-header .rocket-white {
color: #ffffff;
font-weight: 700; /* Bold for ROCKET LAB */
font-size: inherit;
letter-spacing: inherit;
}
.main-header .rocket-red {
color: #dc2626;
font-weight: 800; /* Extra bold for emphasis */
font-size: inherit;
letter-spacing: inherit;
}
.subtitle {
color: #aaa;
font-size: 0.8rem;
letter-spacing: 1.5px;
margin-top: 10px;
padding-left: 30px;
font-weight: 400;
text-transform: uppercase;
}
/* Sidebar styling */
.sidebar-header {
font-size: 1.1rem;
font-weight: 600;
color: #dc2626;
letter-spacing: 1.5px;
text-transform: uppercase;
margin-bottom: 15px;
padding-bottom: 10px;
border-bottom: 2px solid #333;
}
/* Sidebar sample images - professional and compact */
.sidebar-sample {
background: #1a1a1a;
border: 1px solid #333;
border-radius: 6px;
padding: 8px;
margin-bottom: 10px;
cursor: pointer;
transition: background 0.2s, border-color 0.2s, box-shadow 0.2s;
}
.sidebar-sample:hover {
border-color: #dc2626;
background: #1f1f1f;
/* Removed transform to prevent shaking */
}
.sample-label {
color: #bbb;
font-size: 0.75rem;
text-align: center;
margin-top: 8px;
margin-bottom: 5px;
text-transform: uppercase;
letter-spacing: 1px;
font-weight: 500;
}
/* Instruction button */
.instruction-btn {
background: transparent;
border: 1px solid #dc2626;
color: #dc2626;
padding: 8px 16px;
font-size: 0.85rem;
letter-spacing: 1px;
text-transform: uppercase;
border-radius: 4px;
cursor: pointer;
transition: background 0.2s, border-color 0.2s, box-shadow 0.2s;
margin-bottom: 15px;
}
.instruction-btn:hover {
background: #dc2626;
color: white;
}
/* Result box - fixed positioning to prevent shaking */
.result-box {
background: #1a1a1a;
border-radius: 8px;
padding: 30px;
margin: 20px 0;
text-align: center;
position: relative;
will-change: auto;
}
.result-pass {
border: 2px solid #10b981;
}
.result-fail {
border: 2px solid #dc2626;
}
.result-title {
font-size: 1.8rem;
margin: 0;
font-weight: 300;
}
.result-confidence {
font-size: 2.5rem;
margin: 15px 0;
font-weight: bold;
}
/* Metrics row - fixed positioning to prevent shaking */
.metrics-row {
display: flex;
justify-content: center;
gap: 40px;
margin: 20px 0;
position: relative;
will-change: auto;
}
.metric {
text-align: center;
}
.metric-label {
color: #888;
font-size: 0.8rem;
text-transform: uppercase;
letter-spacing: 1px;
}
.metric-value {
color: #ffffff;
font-size: 1.2rem;
font-weight: bold;
margin-top: 5px;
}
/* Upload area - more subtle and professional */
.upload-section {
background: #141414;
border: 1px solid #2a2a2a;
border-radius: 8px;
padding: 20px;
text-align: center;
margin: 15px 0;
}
/* Upload section header - smaller and professional */
.section-header {
font-size: 0.95rem;
font-weight: 600;
color: #ffffff;
text-transform: uppercase;
letter-spacing: 1.5px;
margin-bottom: 15px;
padding-bottom: 10px;
border-bottom: 1px solid #333;
}
/* Buttons - professional style */
.stButton > button {
background: linear-gradient(135deg, #dc2626, #b91c1c);
color: white;
border: none;
padding: 10px 24px;
font-size: 0.85rem;
font-weight: 600;
letter-spacing: 1.2px;
text-transform: uppercase;
border-radius: 4px;
width: 100%;
transition: background 0.2s, border-color 0.2s, box-shadow 0.2s;
box-shadow: 0 2px 8px rgba(220, 38, 38, 0.2);
}
.stButton > button:hover {
background: linear-gradient(135deg, #ef4444, #dc2626);
box-shadow: 0 4px 12px rgba(220, 38, 38, 0.3);
/* Removed transform to prevent shaking */
}
/* Prevent layout shifts and stabilize columns */
[data-testid="column"] {
transition: none !important;
transform: none !important;
}
[data-testid="stHorizontalBlock"] {
transition: none !important;
transform: none !important;
}
/* Disable all Streamlit element animations */
.stApp > div {
transition: none !important;
}
div[data-testid="stDecoration"] {
display: none !important;
}
/* Hide Streamlit branding */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
/* Clean file uploader - professional styling */
[data-testid="stFileUploader"] {
background: transparent;
border: none;
}
[data-testid="stFileUploader"] label {
font-size: 0.85rem !important;
font-weight: 500 !important;
color: #999 !important;
text-transform: uppercase;
letter-spacing: 1px;
}
.uploadedFile {
background: #1a1a1a;
border: 1px solid #333;
border-radius: 4px;
padding: 8px;
}
/* Text colors and typography */
p, span, div {
color: #ffffff;
}
label {
color: #bbb;
font-weight: 500;
}
/* Streamlit section headers */
.stMarkdown h3 {
font-size: 0.95rem !important;
font-weight: 600 !important;
color: #ffffff !important;
text-transform: uppercase;
letter-spacing: 1.5px;
margin-bottom: 15px !important;
padding-bottom: 10px;
border-bottom: 1px solid #333;
}
/* Progress bars minimal */
.stProgress > div > div > div > div {
background: #dc2626;
height: 4px;
}
/* Status badge */
.status-badge {
display: inline-block;
padding: 4px 12px;
border-radius: 20px;
font-size: 0.8rem;
letter-spacing: 1px;
text-transform: uppercase;
font-weight: bold;
}
.status-pass {
background: #10b981;
color: white;
}
.status-fail {
background: #dc2626;
color: white;
}
</style>
""", unsafe_allow_html=True)

    # Load model (cached across Streamlit reruns by load_model itself)
    try:
        model, _, _ = load_model()
    except Exception as e:
        st.error(f"Model Error: {e}")
        st.stop()

    # Header - Professional Rocket Lab style
    st.markdown("""
<div class="main-header">
<h1><span class="rocket-white">ROCKET LAB</span> <span class="rocket-red">COMPONENT DEFECT DETECTION</span></h1>
<div class="subtitle">Made by Gary Phua</div>
</div>
""", unsafe_allow_html=True)

    # Sidebar with sample images
    with st.sidebar:
        st.markdown("""
<div class="sidebar-header">Test Samples</div>
<div style='color: #999; font-size: 0.75rem; margin-bottom: 20px; text-transform: uppercase; letter-spacing: 1px;'>
Click to load sample image
</div>
""", unsafe_allow_html=True)

        # Pick up to three samples: first normal, first defective, last normal.
        examples_dir = Path("examples")
        sample_images = []
        if examples_dir.exists():
            normal_samples = sorted((examples_dir / "normal").glob("*.png"))
            defect_samples = sorted((examples_dir / "defective").glob("*.png"))
            if normal_samples:
                sample_images.append(normal_samples[0])
            if defect_samples:
                sample_images.append(defect_samples[0])
            if len(normal_samples) >= 2:
                sample_images.append(normal_samples[-1])

        for idx, sample_path in enumerate(sample_images):
            with Image.open(sample_path) as img:
                img_thumbnail = img.resize((120, 120), Image.Resampling.LANCZOS)
            label = f"Sample {idx + 1}"

            col1, col2 = st.columns([1, 2], gap="large")
            with col1:
                st.image(img_thumbnail, use_container_width=True)
            with col2:
                st.markdown(f"<div class='sample-label'>{label}</div>", unsafe_allow_html=True)
                if st.button("Load", key=f"sample_{idx}", use_container_width=True):
                    st.session_state['selected_image'] = str(sample_path)
                    st.session_state['image_source'] = 'sample'
                    st.rerun()

    def _on_upload_change():
        # Any change to the uploader supersedes a previously loaded sample.
        st.session_state['image_source'] = 'upload'
        st.session_state['selected_image'] = None

    # Main content area
    main_container = st.container()
    with main_container:
        st.markdown("""
<div class="section-header">Upload Component Image</div>
""", unsafe_allow_html=True)

        uploaded_file = st.file_uploader(
            "Select image file (PNG, JPG, JPEG, BMP)",
            type=['png', 'jpg', 'jpeg', 'bmp'],
            label_visibility="visible",
            on_change=_on_upload_change,
        )

        # Resolve which image to show: an explicitly loaded sample wins until
        # the uploader changes; otherwise fall back to the uploaded file.
        image_source = None
        selected_sample = st.session_state.get('selected_image')
        if st.session_state.get('image_source') == 'sample' and selected_sample:
            image_source = selected_sample
        elif uploaded_file is not None:
            image_source = uploaded_file
            st.session_state['image_source'] = 'upload'

        image = None
        image_np = None
        if image_source is not None:
            try:
                image = Image.open(image_source)
                image_np = np.array(image.convert('RGB'))
            except (OSError, ValueError) as err:
                st.error(f"Could not read image: {err}")
                image = None

        # Results section positioned below upload
        if image is not None:
            st.markdown("""
<div class="section-header">Analysis Results</div>
""", unsafe_allow_html=True)

            col1, col2 = st.columns([1, 2], gap="large")
            with col1:
                st.image(image, use_container_width=True)
            with col2:
                with st.spinner("Analyzing component..."):
                    result = predict_image(image_np, model)

                st.markdown("""
<div style="font-size: 0.9rem; font-weight: 600; color: #999; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 15px;">Quality Assessment</div>
""", unsafe_allow_html=True)

                if result['prediction'] == 'DEFECTIVE':
                    st.markdown("""
<div class="result-box result-fail">
<div class="status-badge status-fail">DEFECT DETECTED</div>
<div style="margin-top: 20px; color: #dc2626; font-size: 1.2rem;">Component Failed Quality Check</div>
</div>
""", unsafe_allow_html=True)
                else:
                    st.markdown("""
<div class="result-box result-pass">
<div class="status-badge status-pass">PASSED</div>
<div style="margin-top: 20px; color: #10b981; font-size: 1.2rem;">Component Passed Quality Check</div>
</div>
""", unsafe_allow_html=True)

                # Metrics
                st.markdown("""
<div class="metrics-row">
<div class="metric">
<div class="metric-label">Confidence</div>
<div class="metric-value">{:.1f}%</div>
</div>
<div class="metric">
<div class="metric-label">Processing Time</div>
<div class="metric-value">{:.0f}ms</div>
</div>
</div>
""".format(result['confidence'] * 100, result['inference_time']), unsafe_allow_html=True)
        else:
            # Empty state - professional
            st.markdown("""
<div style="text-align: center; padding: 80px 40px; background: #141414; border: 1px solid #2a2a2a; border-radius: 8px; margin-top: 40px;">
<div style="font-size: 1.1rem; margin-bottom: 15px; color: #999; font-weight: 600; text-transform: uppercase; letter-spacing: 1.5px;">Ready for Analysis</div>
<div style="font-size: 0.85rem; color: #666; line-height: 1.6;">Upload a component image or select a sample from the sidebar to begin quality inspection</div>
</div>
""", unsafe_allow_html=True)
def _decode_image_payload(payload: str) -> bytes:
    """Decode the base64 image string sent to ``POST /predict``.

    Accepts plain base64 or a ``data:<mime>;base64,<data>`` URL and ignores
    embedded whitespace (e.g. line-wrapped base64).

    Raises:
        ValueError: if the payload is empty, non-ASCII or not valid base64
            (``binascii.Error`` is a ``ValueError`` subclass).
    """
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]
    compact = "".join(payload.split())
    data = base64.b64decode(compact, validate=True)
    if not data:
        raise ValueError("empty image payload")
    return data


def run_api():
    """Build the FastAPI app and serve it, or return it on Vercel.

    Endpoints: ``/`` (index), ``/health``, ``POST /predict`` (JSON body with a
    base64 ``image``), and ``/interface`` (a minimal HTML client). When the
    ``VERCEL`` environment variable is set the app is returned. Otherwise the
    call blocks, serving with uvicorn on 0.0.0.0:8000, and returns ``None``.
    """
    from fastapi import FastAPI, HTTPException
    from fastapi.middleware.cors import CORSMiddleware
    from fastapi.responses import HTMLResponse
    from pydantic import BaseModel
    import uvicorn

    app = FastAPI(
        title="Defect Detection API",
        version="1.0.0"
    )

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        # Wildcard origins must not be combined with credentials; this API
        # uses neither cookies nor auth headers, so credentials stay off.
        allow_credentials=False,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    class PredictionRequest(BaseModel):
        image: str  # base64 encoded (optionally as a data: URL)

    @app.on_event("startup")
    async def startup():
        try:
            load_model()
            print("✅ Model loaded successfully")
        except Exception as e:
            print(f"❌ Model loading failed: {e}")

    @app.get("/")
    async def root():
        return {
            "message": "Defect Detection API",
            "endpoints": {
                "health": "/health",
                "predict": "/predict",
                "interface": "/interface"
            }
        }

    @app.get("/health")
    async def health():
        return {"status": "healthy", "model_loaded": _model_cache["model"] is not None}

    # Plain `def`: FastAPI runs it in a worker thread, so the blocking model
    # inference does not stall the event loop for other requests.
    @app.post("/predict")
    def predict(request: PredictionRequest):
        if _model_cache["model"] is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        # Client errors (bad base64, unreadable image) -> 400, not 500.
        try:
            image_bytes = _decode_image_payload(request.image)
            with Image.open(io.BytesIO(image_bytes)) as image:
                image_np = np.array(image.convert('RGB'))
        except (ValueError, OSError, Image.DecompressionBombError) as e:
            raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

        try:
            return predict_image(image_np, _model_cache["model"])
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.get("/interface")
    async def interface():
        html = """
<!DOCTYPE html>
<html>
<head>
<title>RKLB Defect Detection</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #0a0a0a;
color: white;
min-height: 100vh;
padding: 40px 20px;
}
.container {
max-width: 900px;
margin: 0 auto;
}
h1 {
text-align: center;
font-size: 1.5rem;
font-weight: 300;
letter-spacing: 3px;
margin-bottom: 10px;
padding-bottom: 20px;
border-bottom: 2px solid #dc2626;
}
.subtitle {
text-align: center;
color: #999;
font-size: 0.7rem;
letter-spacing: 1px;
font-style: italic;
margin-bottom: 40px;
}
.upload-area {
border: 2px dashed #333;
padding: 40px;
text-align: center;
background: #1a1a1a;
border-radius: 8px;
margin: 30px 0;
}
.result {
margin: 30px 0;
padding: 30px;
border-radius: 8px;
background: #1a1a1a;
text-align: center;
}
.result-pass { border: 2px solid #10b981; }
.result-fail { border: 2px solid #dc2626; }
button {
background: #dc2626;
color: white;
padding: 10px 30px;
border: none;
border-radius: 4px;
cursor: pointer;
text-transform: uppercase;
letter-spacing: 1px;
font-size: 0.9rem;
}
button:hover { background: #b91c1c; }
#preview img {
max-width: 400px;
max-height: 400px;
margin: 20px auto;
display: block;
border: 1px solid #333;
border-radius: 8px;
}
.confidence {
font-size: 2.5rem;
font-weight: bold;
margin: 20px 0;
}
.status {
display: inline-block;
padding: 4px 12px;
border-radius: 20px;
font-size: 0.8rem;
text-transform: uppercase;
letter-spacing: 1px;
font-weight: bold;
margin-bottom: 10px;
}
.status-pass { background: #10b981; }
.status-fail { background: #dc2626; }
</style>
</head>
<body>
<div class="container">
<h1>ROCKET LAB <span style="color: #dc2626;">COMPONENT DEFECT DETECTION SYSTEM</span></h1>
<p class="subtitle">Made by Gary Phua</p>
<div class="upload-area">
<input type="file" id="imageInput" accept="image/*" style="margin-bottom: 20px;">
<br>
<button onclick="analyze()">Analyze Component</button>
</div>
<div id="preview"></div>
<div id="result"></div>
</div>
<script>
function analyze() {
const input = document.getElementById('imageInput');
const file = input.files[0];
if (!file) return alert('Select an image');
const reader = new FileReader();
reader.onload = e => {
document.getElementById('preview').innerHTML = '<img src="' + e.target.result + '">';
const base64 = e.target.result.split(',')[1];
document.getElementById('result').innerHTML = '<div class="result">Analyzing...</div>';
fetch('/predict', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({image: base64})
})
.then(r => r.json().then(data => {
if (!r.ok) {
throw new Error(typeof data.detail === 'string' ? data.detail : ('HTTP ' + r.status));
}
return data;
}))
.then(data => {
const passClass = data.prediction === 'DEFECTIVE' ? 'result-fail' : 'result-pass';
const statusClass = data.prediction === 'DEFECTIVE' ? 'status-fail' : 'status-pass';
const statusText = data.prediction === 'DEFECTIVE' ? 'DEFECT DETECTED' : 'PASSED';
document.getElementById('result').innerHTML =
'<div class="result ' + passClass + '">' +
'<div class="status ' + statusClass + '">' + statusText + '</div>' +
'<div class="confidence">' + (data.confidence * 100).toFixed(1) + '%</div>' +
'<div style="color: #888;">CONFIDENCE</div>' +
'<div style="margin-top: 20px; color: #888; font-size: 0.9rem;">' +
'Time: ' + data.inference_time.toFixed(0) + 'ms</div>' +
'</div>';
})
.catch(err => {
const box = document.createElement('div');
box.className = 'result result-fail';
box.textContent = 'Error: ' + err.message;
const resultEl = document.getElementById('result');
resultEl.innerHTML = '';
resultEl.appendChild(box);
});
};
reader.readAsDataURL(file);
}
</script>
</body>
</html>
"""
        return HTMLResponse(content=html)

    # For Vercel deployment
    if os.environ.get('VERCEL'):
        return app

    # Local server
    uvicorn.run(app, host="0.0.0.0", port=8000)
def _summarize_results(results):
    """Format the batch summary line printed by ``run_cli``.

    Reports how many results were predicted DEFECTIVE and the percentage; an
    empty batch reports 0.0% instead of dividing by zero.
    """
    total = len(results)
    defective = sum(1 for r in results if r['prediction'] == 'DEFECTIVE')
    rate = defective / total * 100 if total else 0.0
    return f"\n📊 Results: {defective}/{total} defective ({rate:.1f}%)"


def run_cli(args):
    """Run the command-line interface.

    Args:
        args: Parsed ``argparse`` namespace; reads ``model``, ``image``,
            ``directory`` and ``output``.

    With ``--image`` a single prediction is printed. With ``--directory``
    every .png/.jpg/.jpeg/.bmp file below it (recursive, sorted) is
    classified. A summary is printed and, if ``--output`` is given, the
    per-image results are written as JSON.
    """
    model, _, info = load_model(args.model)
    print(f"✅ Model loaded: {info['model_type']} (Acc: {info['accuracy']:.1f}%)")

    if args.image:
        # Single image prediction; OpenCV loads BGR, the model expects RGB.
        image = cv2.imread(args.image)
        if image is None:
            print(f"❌ Cannot load image: {args.image}")
            return
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        result = predict_image(image, model)
        print(f"\n📷 Image: {args.image}")
        print(f"🎯 Prediction: {result['prediction']}")
        print(f"📊 Confidence: {result['confidence']:.2%}")
        print(f"⏱️ Inference: {result['inference_time']:.1f}ms")

    elif args.directory:
        # Batch prediction
        directory = Path(args.directory)
        if not directory.is_dir():
            print(f"❌ Not a directory: {args.directory}")
            return

        results = []
        for img_path in sorted(directory.glob("**/*")):
            if img_path.suffix.lower() not in ('.png', '.jpg', '.jpeg', '.bmp'):
                continue
            image = cv2.imread(str(img_path))
            if image is None:
                continue  # unreadable file: skip, as before
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            result = predict_image(image, model)
            result['path'] = str(img_path)
            results.append(result)
            marker = '🔴' if result['prediction'] == 'DEFECTIVE' else '🟢'
            print(f"{marker} {img_path.name}: {result['prediction']} ({result['confidence']:.1%})")

        print(_summarize_results(results))

        if args.output:
            with open(args.output, 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=2)
            print(f"💾 Saved to {args.output}")

    else:
        print("❌ CLI mode needs --image or --directory")
def main():
    """Entry point: choose web, API or CLI mode.

    When the file is being executed by Streamlit's script runner, go straight
    to the web UI. Otherwise parse ``--mode`` (default ``web``) together with
    the CLI-specific options and dispatch.
    """
    # Check if running through streamlit
    if "streamlit.runtime.scriptrunner" in sys.modules:
        run_streamlit()
        return

    parser = argparse.ArgumentParser(description='Defect Detection Application')
    parser.add_argument('--mode', choices=['web', 'api', 'cli'], default='web',
                        help='Run mode: web (Streamlit), api (FastAPI), or cli')
    string_options = (
        ('--model', 'Model path'),
        ('--image', 'Single image path (CLI mode)'),
        ('--directory', 'Directory of images (CLI mode)'),
        ('--output', 'Save results to JSON (CLI mode)'),
    )
    for flag, help_text in string_options:
        parser.add_argument(flag, type=str, help=help_text)
    args = parser.parse_args()

    # `choices` restricts mode to exactly these three values.
    if args.mode == 'cli':
        run_cli(args)
    elif args.mode == 'api':
        run_api()
    else:
        run_streamlit()
# For Vercel deployment: expose the FastAPI app at module level. When VERCEL
# is set, run_api() returns the app instead of starting uvicorn; otherwise
# `app` stays None.
app = run_api() if os.environ.get('VERCEL') else None

if __name__ == "__main__":
    main()