# Spaces: Sleeping
import base64
import logging
import os
import secrets
import threading

from fastapi import FastAPI, File, UploadFile, HTTPException, Header, Depends
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
import numpy as np
import tensorflow as tf
import cv2
from huggingface_hub import hf_hub_download
# ---------- CONFIG ----------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Public placeholder, kept only as a backward-compatible fallback. It is
# visible in source control, so it must not protect a real deployment.
_PLACEHOLDER_API_KEY = "your-secret-api-key"
# Read the key from the environment (e.g. a Spaces secret) rather than
# hard-coding it. An unset or empty variable falls back to the placeholder.
API_KEY = os.environ.get("API_KEY") or _PLACEHOLDER_API_KEY
if API_KEY == _PLACEHOLDER_API_KEY:
    logger.warning(
        "API_KEY environment variable is not set; using the placeholder key."
    )

# Side length in pixels of the square image fed to the model.
IMG_SIZE = 256
# Overlay colour per mask class, in OpenCV BGR channel order:
#   1 -> green, 2 -> red (the prediction route counts class 1 as healthy and
#   class 2 as diseased); 0 -> black, presumably background (TODO confirm).
CLASS_COLORS = {0: (0, 0, 0), 1: (0, 255, 0), 2: (0, 0, 255)}
# ---------- API SETUP ----------
app = FastAPI()

# Cross-origin settings for Starlette's CORSMiddleware: every origin, method
# and header is accepted, and credentialed requests are allowed.
# NOTE(review): FastAPI's CORS documentation says allow_origins/methods/headers
# should not be ["*"] when allow_credentials=True. Clients authenticate with
# the x-api-key header (see verify_api_key), so credentials may not be needed.
# Confirm with the frontend before tightening this.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
def verify_api_key(x_api_key: str = Header(...)):
    """FastAPI dependency that rejects requests lacking the expected API key.

    FastAPI fills ``x_api_key`` from the required ``x-api-key`` request
    header. A missing header is rejected by FastAPI before this body runs.

    Raises:
        HTTPException: 403 when the supplied key does not equal ``API_KEY``.
    """
    # Constant-time comparison, so response timing does not reveal how much
    # of the key prefix was correct. Compare UTF-8 bytes because
    # compare_digest raises TypeError for non-ASCII ``str`` arguments.
    supplied = x_api_key.encode("utf-8")
    expected = API_KEY.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(status_code=403, detail="Invalid API Key")
# ---------- LOAD MODEL ----------
# Hugging Face Hub location of the trained segmentation model.
MODEL_REPO_ID = "rishab1090/potato"
MODEL_FILENAME = "unet_tf.keras"

try:
    # Intended to keep Hugging Face files under a writable path on Spaces.
    # NOTE(review): huggingface_hub appears to read HF_HOME into its constants
    # when it is first imported, and the import is at the top of this file.
    # This assignment therefore probably does not change its defaults; the
    # explicit cache_dir below decides where the model is stored. Confirm.
    os.environ["HF_HOME"] = "/tmp/huggingface"
    model_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=MODEL_FILENAME,
        cache_dir="/tmp/hf_cache",  # writable cache dir; avoids FS write issues
    )
    model = tf.keras.models.load_model(model_path)
except Exception as e:
    # logger.exception also records the traceback of the failure.
    logger.exception("❌ Failed to load model: %s", e)
    # Chain the cause so the original error stays visible when startup aborts.
    raise RuntimeError(f"Model load failed: {e}") from e
else:
    logger.info("✅ Model loaded successfully from %s.", MODEL_FILENAME)
# ---------- UTILS ----------
def decode_mask_to_overlay(image_bgr, mask, *, class_colors=None, alpha=0.5):
    """Return a copy of ``image_bgr`` tinted with each pixel's class colour.

    Args:
        image_bgr: Image array, expected (H, W, 3). It is not modified; the
            blending happens on a copy.
        mask: Integer class map whose shape matches the first two axes of
            ``image_bgr``.
        class_colors: Mapping {class_id: (b, g, r)}. Defaults to
            ``CLASS_COLORS``.
        alpha: Weight of the class colour in the blend. 0 keeps the image and
            1 paints the solid colour. The default 0.5 is an even blend.

    Returns:
        The blended copy. Pixels whose class is missing from ``class_colors``
        are left unchanged.

    Raises:
        ValueError: if ``alpha`` is outside [0, 1].
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be within [0, 1], got {alpha!r}")
    if class_colors is None:
        class_colors = CLASS_COLORS
    overlay = image_bgr.copy()
    for class_id, color in class_colors.items():
        # Build the boolean selector once per class instead of twice.
        selected = mask == class_id
        blended = overlay[selected] * (1.0 - alpha) + np.asarray(color) * alpha
        # astype truncates toward zero (e.g. 177.5 -> 177).
        overlay[selected] = blended.astype(np.uint8)
    return overlay
def image_to_base64(img: np.ndarray) -> str:
    """Encode ``img`` as PNG and return the file bytes as base64 text.

    Args:
        img: Image array accepted by ``cv2.imencode``. In this file it is the
            overlay built by ``decode_mask_to_overlay``.

    Returns:
        Base64 string of the PNG-encoded image.

    Raises:
        ValueError: if OpenCV reports that PNG encoding failed.
    """
    # imencode returns (success_flag, buffer). Check the flag so a failed
    # encode raises instead of returning an empty or garbage string.
    ok, buffer = cv2.imencode(".png", img)
    if not ok:
        raise ValueError("PNG encoding failed")
    return base64.b64encode(buffer).decode("utf-8")
# ---------- PREDICTION ROUTE ----------
# Serialises use of the shared Keras ``model``. Inference runs in a worker
# thread so the event loop stays responsive. This lock keeps the
# one-prediction-at-a-time behaviour the handler had when predict() ran
# directly on the event loop.
_MODEL_LOCK = threading.Lock()


def _decode_upload(contents):
    """Decode raw upload bytes into a BGR image, or return None if impossible."""
    if not contents:
        # cv2.imdecode can raise (rather than return None) on an empty buffer.
        return None
    try:
        return cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    except cv2.error:
        return None


def _run_segmentation(img_input):
    """Run the model on a batch and return the prediction for its first item.

    This call blocks; the async route invokes it through ``run_in_threadpool``.
    """
    with _MODEL_LOCK:
        # verbose=0 suppresses Keras' per-call progress bar in the server log.
        return model.predict(img_input, verbose=0)[0]


# NOTE(review): no route decorator was visible on this handler, so it was never
# registered. "/predict" is an assumed path; confirm against existing clients.
@app.post("/predict")
async def predict_severity(
    file: UploadFile = File(...),
    x_api_key: str = Depends(verify_api_key),
):
    """Segment an uploaded image and report the diseased-area percentage.

    Processing steps:
      1. Decode the upload with OpenCV (BGR, 3 channels).
      2. Resize to IMG_SIZE x IMG_SIZE and scale to [0, 1].
      3. Run the model and take the per-pixel argmax as the class mask.
    Class 1 pixels count as healthy and class 2 pixels as diseased. Severity
    is diseased / (healthy + diseased) * 100, or 0.0 when neither occurs.

    Args:
        file: Uploaded image in any format ``cv2.imdecode`` can read.
        x_api_key: Result of ``verify_api_key``; present only to enforce auth.

    Returns:
        A dict with three keys:
        ``severity``: percent, rounded to 2 decimals.
        ``class_counts``: {class_id: pixel_count}.
        ``segmentation_mask_base64``: base64 PNG of the colour overlay.

    Raises:
        HTTPException: 400 if the upload is empty or not a decodable image;
            500 for any other failure during prediction.
    """
    try:
        contents = await file.read()
        img_bgr = _decode_upload(contents)
        if img_bgr is None:
            # A bad upload is a client error, not a server failure.
            raise HTTPException(status_code=400, detail="Invalid image file")

        img_resized = cv2.resize(img_bgr, (IMG_SIZE, IMG_SIZE))
        # Add a batch axis: (IMG_SIZE, IMG_SIZE, 3) -> (1, IMG_SIZE, IMG_SIZE, 3).
        img_input = np.expand_dims(img_resized.astype(np.float32) / 255.0, axis=0)

        prediction = await run_in_threadpool(_run_segmentation, img_input)
        mask = np.argmax(prediction, axis=-1).astype(np.uint8)

        unique, counts = np.unique(mask, return_counts=True)
        class_counts = {int(k): int(v) for k, v in zip(unique, counts)}
        healthy = class_counts.get(1, 0)
        diseased = class_counts.get(2, 0)
        total = healthy + diseased
        severity_percent = diseased / total * 100 if total > 0 else 0.0

        overlay = decode_mask_to_overlay(img_resized, mask)
        return {
            "severity": round(severity_percent, 2),
            "class_counts": class_counts,
            "segmentation_mask_base64": image_to_base64(overlay),
        }
    except HTTPException:
        # Errors raised deliberately above keep their own status code.
        raise
    except Exception as e:
        # NOTE(review): str(e) exposes internal error text to API clients.
        logger.exception("Error during prediction: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Development entry point: serve with uvicorn when this file is executed
    # directly. uvicorn is imported here so it is only needed in that case.
    import uvicorn

    # "app:app" is a "module:attribute" import string, so this assumes the
    # file is saved as app.py.
    # NOTE(review): uvicorn imports that module afresh, so the model is loaded
    # a second time when started this way. Passing the ``app`` object would
    # avoid that, but only after this block moves below the last route
    # definition, because uvicorn.run() blocks.
    bind = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run("app:app", **bind)
# NOTE(review): no route decorator was visible on this function, so it was never
# registered. The "/" path is inferred from the name read_root.
@app.get("/")
def read_root():
    """Report that the server is up."""
    # The status text previously held mis-encoded bytes ("β" plus control
    # characters), most likely a corrupted UTF-8 "✅" emoji.
    return {"status": "Server is running ✅"}