"""FastAPI service that segments leaf images and reports disease severity.

A Keras segmentation model is downloaded from the Hugging Face Hub when this
module is imported. ``POST /predict_severity`` accepts an image upload (guarded
by an ``x-api-key`` header) and returns the diseased-pixel percentage, the
per-class pixel counts, and a base64 PNG of the colour-coded overlay.
"""

import base64
import logging
import os
import secrets

# huggingface_hub reads HF_HOME once, when it is first imported, so the override
# must happen before the import below; setting it afterwards has no effect.
# /tmp is used because it is writable on hosts such as Hugging Face Spaces.
os.environ["HF_HOME"] = "/tmp/huggingface"

import cv2  # noqa: E402
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
from fastapi import Depends, FastAPI, File, Header, HTTPException, UploadFile  # noqa: E402
from fastapi.concurrency import run_in_threadpool  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402

# ---------- CONFIG ----------
# Prefer the API_KEY environment variable so the secret need not live in
# source control; the literal is only a placeholder fallback.
API_KEY = os.environ.get("API_KEY", "your-secret-api-key")
# Side length (pixels) the upload is resized to before inference.
IMG_SIZE = 256
# Overlay colour per predicted class id, in OpenCV's BGR channel order
# (1 -> green, 2 -> red). Class 0 is blended with black, i.e. darkened.
CLASS_COLORS = {0: (0, 0, 0), 1: (0, 255, 0), 2: (0, 0, 255)}
# Class ids counted as healthy / diseased tissue in the severity ratio.
# NOTE(review): class 0 is presumably background — confirm against training labels.
HEALTHY_CLASS = 1
DISEASED_CLASS = 2

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class InvalidImageError(ValueError):
    """Raised when an upload cannot be decoded as an image (a client error)."""


# ---------- API SETUP ----------
app = FastAPI()
# NOTE(review): browsers refuse a wildcard origin on credentialed requests, and
# auth here uses a header rather than cookies, so allow_credentials=False is
# probably sufficient — confirm with the frontend before changing.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def verify_api_key(x_api_key: str = Header(...)) -> str:
    """Reject the request unless the ``x-api-key`` header equals ``API_KEY``.

    Returns:
        The validated key, so dependants receive it rather than ``None``.

    Raises:
        HTTPException: 403 when the key does not match.
    """
    # Constant-time comparison avoids leaking the key through response timing.
    # Bytes are compared because compare_digest rejects non-ASCII str input.
    if not secrets.compare_digest(x_api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Invalid API Key")
    return x_api_key


# ---------- LOAD MODEL ----------
try:
    model_path = hf_hub_download(
        repo_id="rishab1090/potato",
        filename="unet_tf.keras",
        cache_dir="/tmp/hf_cache",  # writable location; avoids FS write issues
    )
    model = tf.keras.models.load_model(model_path)
    logger.info("✅ Model loaded successfully from unet_tf.keras.")
except Exception as e:
    logger.exception("❌ Failed to load model: %s", e)
    # Fail fast at import: the service is useless without the model.
    raise RuntimeError(f"Model load failed: {e}") from e


# ---------- UTILS ----------
def decode_mask_to_overlay(image_bgr: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Blend each pixel 50/50 with the colour of its predicted class.

    Args:
        image_bgr: uint8 image whose first two dimensions match ``mask``.
        mask: per-pixel class ids; ids absent from ``CLASS_COLORS`` are left as-is.

    Returns:
        A new uint8 image; ``image_bgr`` is not modified.
    """
    overlay = image_bgr.copy()
    for class_id, color in CLASS_COLORS.items():
        selected = mask == class_id  # compute the boolean index once per class
        overlay[selected] = (
            overlay[selected] * 0.5 + np.array(color) * 0.5
        ).astype(np.uint8)
    return overlay


def image_to_base64(img: np.ndarray) -> str:
    """Encode ``img`` as PNG and return the bytes as a base64 ASCII string.

    Raises:
        ValueError: if OpenCV fails to encode the image.
    """
    ok, buffer = cv2.imencode(".png", img)
    if not ok:
        raise ValueError("Failed to encode image as PNG")
    return base64.b64encode(buffer).decode("utf-8")


def _decode_image(contents: bytes) -> np.ndarray:
    """Decode raw upload bytes into a 3-channel BGR image.

    Raises:
        InvalidImageError: if the upload is empty or not a decodable image.
    """
    # cv2.imdecode cannot take an empty buffer, so reject that case up front.
    if not contents:
        raise InvalidImageError("Empty image file")
    img_bgr = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if img_bgr is None:
        raise InvalidImageError("Invalid image file")
    return img_bgr


def _segment(img_bgr: np.ndarray):
    """Resize, scale to [0, 1], and run the model on one image.

    Returns:
        ``(img_resized, mask)``: the IMG_SIZE x IMG_SIZE input image and the
        per-pixel argmax class ids as uint8.
    """
    img_resized = cv2.resize(img_bgr, (IMG_SIZE, IMG_SIZE))
    # NOTE(review): the model receives OpenCV's BGR channel order — confirm
    # this matches the preprocessing used in training.
    img_input = np.expand_dims(img_resized.astype(np.float32) / 255.0, axis=0)
    # verbose=0 suppresses the progress bar Keras would otherwise log per call.
    prediction = model.predict(img_input, verbose=0)[0]
    mask = np.argmax(prediction, axis=-1).astype(np.uint8)
    return img_resized, mask


def _severity_stats(mask: np.ndarray):
    """Return ``(severity_percent rounded to 2 dp, {class_id: pixel_count})``.

    Severity is diseased / (healthy + diseased) * 100, or 0.0 when neither
    class is present.
    """
    unique, counts = np.unique(mask, return_counts=True)
    class_counts = {int(k): int(v) for k, v in zip(unique, counts)}
    healthy = class_counts.get(HEALTHY_CLASS, 0)
    diseased = class_counts.get(DISEASED_CLASS, 0)
    leaf_pixels = healthy + diseased
    severity_percent = (diseased / leaf_pixels) * 100 if leaf_pixels > 0 else 0.0
    return round(severity_percent, 2), class_counts


# ---------- ROUTES ----------
@app.post("/predict_severity")
async def predict_severity(
    file: UploadFile = File(...),
    x_api_key: str = Depends(verify_api_key),
):
    """Segment the uploaded image and report disease severity.

    Returns:
        JSON with ``severity`` (percent), ``class_counts`` (pixels per class id),
        and ``segmentation_mask_base64`` (PNG overlay, base64-encoded).

    Raises:
        HTTPException: 400 if the upload is empty or not a decodable image;
            500 for any other failure during inference or encoding.
    """
    try:
        contents = await file.read()
        img_bgr = _decode_image(contents)
        # Inference is CPU-bound and blocking; running it in the threadpool
        # keeps the event loop free to serve other requests meanwhile.
        img_resized, mask = await run_in_threadpool(_segment, img_bgr)
        severity, class_counts = _severity_stats(mask)
        mask_base64 = image_to_base64(decode_mask_to_overlay(img_resized, mask))
    except InvalidImageError as e:
        logger.warning("Rejected upload: %s", e)
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error during prediction: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "severity": severity,
        "class_counts": class_counts,
        "segmentation_mask_base64": mask_base64,
    }


@app.get("/")
def read_root():
    """Liveness check."""
    return {"status": "Server is running ✅"}


if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than "app:app": the import string makes uvicorn
    # re-import this module (loading the model a second time) and only works
    # when the file is literally named app.py.
    uvicorn.run(app, host="0.0.0.0", port=8000)