Spaces:
Running
Running
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from tensorflow.keras.applications.mobilenet_v2 import preprocess_input | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from fastapi.requests import Request | |
| from fastapi.responses import JSONResponse | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
| import os | |
| import json | |
| import tensorflow as tf | |
app = FastAPI(title="PCB Defect Detection API")

# Static assets and Jinja2 HTML templates; both directories are given as
# relative paths, so they resolve against the process working directory.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# ── Model artefacts and shared runtime state ──────────────────────────────────
MODEL_PATH = "pcb_model.keras"  # NOTE(review): not referenced by the loader in this view
CLASS_PATH = "class_names.json"  # index -> class-name map read by load_model()
IMG_SIZE = (224, 224)  # size every upload is resized to before inference

# Module-level state filled in by load_model(); model stays None when no
# trained weights are available, which switches predict() into demo mode.
model: "tf.keras.Model | None" = None
class_names: "dict[str, str]" = {}
def build_model(num_classes: int = 6, input_shape: "tuple[int, int, int] | None" = None):
    """Build the MobileNetV2-based classifier that the saved weights are loaded into.

    The layer stack must match the architecture the weights file was saved
    from; change it only together with retraining.

    Args:
        num_classes: Width of the final softmax layer. Defaults to 6, the
            number of entries in the built-in class map.
        input_shape: ``(height, width, channels)`` of the network input.
            Defaults to ``(*IMG_SIZE, 3)`` so the model always agrees with
            the size ``preprocess_image`` resizes uploads to.

    Returns:
        An uncompiled Keras ``Model``.
    """
    from tensorflow.keras.applications import MobileNetV2
    from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, BatchNormalization
    from tensorflow.keras.models import Model

    if input_shape is None:
        # Derive from IMG_SIZE instead of hard-coding 224x224 so the network
        # input cannot silently drift from the preprocessing resize target.
        input_shape = (*IMG_SIZE, 3)

    # weights=None: no pretrained download; all weights come from the saved file.
    base = MobileNetV2(input_shape=input_shape, include_top=False, weights=None)
    x = GlobalAveragePooling2D()(base.output)
    x = BatchNormalization()(x)
    x = Dense(512, activation="relu")(x)
    x = Dropout(0.4)(x)
    x = Dense(256, activation="relu")(x)
    x = Dropout(0.3)(x)
    out = Dense(num_classes, activation="softmax")(x)
    return Model(base.input, out)
# NOTE(review): no startup decorator was present, so nothing ever invoked this
# hook and the service always ran in demo mode. ``on_event`` is deprecated in
# recent FastAPI releases in favour of a ``lifespan`` handler but still works;
# migrate if the pinned FastAPI version allows.
@app.on_event("startup")
async def load_model():
    """Load the classifier weights and the index-to-class-name map.

    Populates the module globals ``model`` and ``class_names``.

    * If ``pcb_weights.weights.h5`` exists, the architecture is rebuilt with
      ``build_model()`` and the weights are loaded into it. Otherwise
      ``model`` stays ``None`` and ``predict`` returns mock results (demo mode).
    * ``class_names`` is read from ``CLASS_PATH`` when present, otherwise a
      built-in six-class map is used. The JSON may be an object
      (``{"0": "missing_hole", ...}``) or a list of names in index order.
      Both forms are normalised to string-index keys.

    Raises:
        ValueError: ``CLASS_PATH`` contains JSON that is neither an object
            nor a list.
    """
    global model, class_names

    # NOTE(review): MODEL_PATH is defined at module level but unused; this
    # weights file is what is actually loaded.
    weights_path = "pcb_weights.weights.h5"
    if os.path.exists(weights_path):
        model = build_model()
        model.load_weights(weights_path)
        print("✅ Model loaded successfully")
    else:
        print("⚠️ Model not found — using demo mode")

    if os.path.exists(CLASS_PATH):
        with open(CLASS_PATH, encoding="utf-8") as f:
            loaded = json.load(f)
        # predict() looks names up by str(index), so a bare list of names must
        # be converted to an index-keyed mapping first.
        if isinstance(loaded, list):
            loaded = dict(enumerate(loaded))
        if not isinstance(loaded, dict):
            raise ValueError(
                f"{CLASS_PATH} must contain a JSON object or list, "
                f"got {type(loaded).__name__}"
            )
        class_names = {str(k): v for k, v in loaded.items()}
    else:
        class_names = {
            "0": "missing_hole",
            "1": "mouse_bite",
            "2": "open_circuit",
            "3": "short",
            "4": "spur",
            "5": "spurious_copper",
        }
# ── Defect descriptions ───────────────────────────────────────────────────────
# One row per defect class:
#   (class key, display label, description, severity, UI colour).
# The class keys match the names in the fallback class map used by load_model().
_DEFECT_ROWS = (
    (
        "missing_hole",
        "Missing Hole",
        "A required drill hole is absent from the PCB. "
        "This prevents component mounting and causes assembly failure.",
        "High",
        "#FF4444",
    ),
    (
        "mouse_bite",
        "Mouse Bite",
        "Small notches or indentations along the PCB edge, resembling bite marks. "
        "Usually caused by routing errors.",
        "Medium",
        "#FF8C00",
    ),
    (
        "open_circuit",
        "Open Circuit",
        "A broken trace or gap in the copper path that interrupts electrical continuity.",
        "High",
        "#FF4444",
    ),
    (
        "short",
        "Short Circuit",
        "Unintended connection between two copper traces, "
        "causing electrical short that can damage components.",
        "Critical",
        "#CC0000",
    ),
    (
        "spur",
        "Spur",
        "An unwanted protrusion on a copper trace. "
        "Can cause unintended connections with adjacent traces.",
        "Low",
        "#28A745",
    ),
    (
        "spurious_copper",
        "Spurious Copper",
        "Extra copper remaining on the board after etching. "
        "Can cause short circuits if near other traces.",
        "Medium",
        "#FF8C00",
    ),
)

# Class key -> {"label", "description", "severity", "color"}; consumed by the
# /predict response builder.
DEFECT_INFO = {
    key: {"label": label, "description": text, "severity": level, "color": hex_colour}
    for key, label, text, level, hex_colour in _DEFECT_ROWS
}
def preprocess_image(image_bytes: bytes) -> np.ndarray:
    """Turn raw upload bytes into a single-image batch ready for the model.

    Steps:
    1. Decode with PIL and force 3-channel RGB.
    2. Resize to ``IMG_SIZE``.
    3. Cast to float32 and apply MobileNetV2's ``preprocess_input``.

    Args:
        image_bytes: Encoded image file contents in any format PIL can open.

    Returns:
        A float32 array with a leading batch axis of length 1, i.e.
        ``(1, 224, 224, 3)`` for the current ``IMG_SIZE``.

    Raises:
        Whatever PIL raises for data it cannot decode; the caller maps this
        to an HTTP 400.
    """
    with Image.open(io.BytesIO(image_bytes)) as decoded:
        rgb = decoded.convert("RGB").resize(IMG_SIZE)
    pixels = preprocess_input(np.array(rgb, dtype=np.float32))
    return pixels[np.newaxis, ...]
# NOTE(review): no route decorator was present, so this page was never served.
# The "/" path is inferred from the handler's role; confirm.
@app.get("/")
async def home(request: Request):
    """Render the upload page from ``templates/index.html``."""
    return templates.TemplateResponse("index.html", {"request": request})
# NOTE(review): no route decorator was present, so this endpoint was never
# registered. Method and path are inferred from the handler name; confirm they
# match what templates/index.html submits to.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify the PCB defect shown in an uploaded image.

    Args:
        file: Multipart upload. It must declare an ``image/*`` content type
            and be at most 10 MB.

    Returns:
        JSONResponse with:
        * the top class: ``defect``, ``label``, ``confidence`` (percent),
          ``severity``, ``color`` and ``description``;
        * ``top3``: a list of ``{"class", "confidence"}`` dicts in descending
          confidence order.
        When no model is loaded (demo mode) the prediction is random.

    Raises:
        HTTPException: 400 for a missing or non-image content type, an
            upload over 10 MB, or bytes that cannot be decoded as an image.
    """
    import random

    from fastapi.concurrency import run_in_threadpool

    max_bytes = 10 * 1024 * 1024

    # content_type is None when the client omits the header; treat that as
    # "not an image" instead of crashing with AttributeError (HTTP 500).
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image.")

    # Read at most one byte past the limit, so an oversize upload is rejected
    # without buffering the entire body in memory.
    contents = await file.read(max_bytes + 1)
    if len(contents) > max_bytes:
        raise HTTPException(status_code=400, detail="Image too large. Max 10MB.")

    # PIL decoding and Keras inference are synchronous and CPU-bound. Run them
    # in the worker threadpool so they do not stall the event loop.
    # NOTE(review): assumes the Keras model tolerates concurrent predict()
    # calls from worker threads; verify under load or guard with a lock.
    try:
        img_array = await run_in_threadpool(preprocess_image, contents)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Could not process image.") from exc

    if model is None:
        # Demo mode: mock prediction. The three top-3 classes are distinct, and
        # the scores are ordered and sum to at most 100%, like real softmax
        # output.
        classes = list(DEFECT_INFO.keys())
        detected = random.choice(classes)
        confidence = round(random.uniform(0.75, 0.97), 4)
        runners_up = random.sample([c for c in classes if c != detected], 2)
        remainder = 1.0 - confidence
        second = random.uniform(remainder / 2, remainder)
        third = random.uniform(0.0, remainder - second)  # <= remainder/2 <= second
        top3_results = [{"class": detected, "confidence": round(confidence * 100, 2)}]
        top3_results += [
            {"class": c, "confidence": round(s * 100, 2)}
            for c, s in zip(runners_up, (second, third))
        ]
    else:
        # verbose=0 suppresses Keras' per-call progress bar in the server log.
        preds = (await run_in_threadpool(model.predict, img_array, verbose=0))[0]
        top_idx = int(np.argmax(preds))
        confidence = float(preds[top_idx])
        detected = class_names.get(str(top_idx), "unknown")
        top3_idx = np.argsort(preds)[::-1][:3]
        top3_results = [
            {"class": class_names.get(str(i), "unknown"), "confidence": round(float(preds[i]) * 100, 2)}
            for i in top3_idx
        ]

    # Classes without a DEFECT_INFO entry (e.g. "unknown") get a generic card.
    info = DEFECT_INFO.get(detected, {
        "label": detected.replace("_", " ").title(),
        "description": "Defect detected on the PCB surface.",
        "severity": "Unknown",
        "color": "#666",
    })

    return JSONResponse({
        "success": True,
        "defect": detected,
        "label": info["label"],
        "confidence": round(confidence * 100, 2),
        "severity": info["severity"],
        "color": info["color"],
        "description": info["description"],
        "top3": top3_results,
    })
# NOTE(review): no route decorator was present, so this probe was never
# exposed. The "/health" path is inferred from the handler name; confirm.
@app.get("/health")
async def health():
    """Liveness probe; ``model_loaded`` is False while running in demo mode."""
    return {"status": "ok", "model_loaded": model is not None}