# potato2 / app.py
# Author: rishab1090 (last commit 3e8dbea, "Update app.py", verified)
from fastapi import FastAPI, File, UploadFile, HTTPException, Header, Depends
from fastapi.middleware.cors import CORSMiddleware
import numpy as np
import tensorflow as tf
import cv2
import base64
import os
import logging
from huggingface_hub import hf_hub_download
# ---------- CONFIG ----------
# Insecure placeholder used only when no real key is configured.
_DEFAULT_API_KEY = "your-secret-api-key"
# Shared secret clients must send in the `x-api-key` header. It is read from the
# environment (e.g. a Space secret) so the real key never has to be committed;
# an unset or empty variable falls back to the placeholder above.
API_KEY = os.environ.get("API_KEY") or _DEFAULT_API_KEY
# Side length (pixels) that uploads are resized to before inference.
# Presumably the model's training input size — TODO confirm.
IMG_SIZE = 256
# Overlay colour per predicted class id, in OpenCV BGR order (they are blended
# onto the BGR image from cv2.imdecode): 0 -> black, 1 -> green, 2 -> red.
CLASS_COLORS = {0: (0, 0, 0), 1: (0, 255, 0), 2: (0, 0, 255)}
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
if API_KEY == _DEFAULT_API_KEY:
    logger.warning("API_KEY environment variable not set; using the insecure placeholder key.")
# ---------- API SETUP ----------
app = FastAPI()

# Fully permissive CORS: any origin, method and header may call the API.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets
# credentialed cross-site requests through (per the Starlette CORS docs the
# request's Origin is echoed back — verify). The API authenticates via the
# x-api-key header rather than cookies, so credentials are probably not
# needed; consider allow_credentials=False or an explicit origin list.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
def verify_api_key(x_api_key: str = Header(...)):
    """FastAPI dependency that rejects requests lacking the correct API key.

    FastAPI binds ``x_api_key`` to the ``x-api-key`` request header, and
    ``Header(...)`` marks that header as required.

    Raises:
        HTTPException: 403 when the supplied key does not match ``API_KEY``.
    """
    import secrets  # only needed here; kept local to avoid touching file imports

    # Constant-time comparison, so response timing does not reveal how much of
    # the key was guessed correctly. Compare bytes rather than str, because
    # compare_digest raises TypeError for non-ASCII str input (a malformed
    # header would otherwise surface as a 500 instead of a 403).
    supplied = x_api_key.encode("utf-8")
    expected = API_KEY.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(status_code=403, detail="Invalid API Key")
# ---------- LOAD MODEL ----------
# Download the segmentation model from the Hugging Face Hub and load it once at
# import time; any failure aborts startup so the server never runs without it.
try:
    # Point the Hugging Face home at a writable location (prevents permission
    # issues on Spaces).
    # NOTE(review): huggingface_hub is imported at the top of the file, before
    # this line, and presumably resolves HF_HOME at import time — so this
    # assignment may have no effect on it. The explicit cache_dir below is what
    # actually controls where the model file is written.
    os.environ["HF_HOME"] = "/tmp/huggingface"
    model_path = hf_hub_download(
        repo_id="rishab1090/potato",
        filename="unet_tf.keras",  # updated model filename
        cache_dir="/tmp/hf_cache",  # writable cache dir; avoids FS write issues
    )
    model = tf.keras.models.load_model(model_path)
    logger.info("βœ… Model loaded successfully from unet_tf.keras.")
except Exception as e:
    # logger.exception records the full traceback, not just the message.
    logger.exception(f"❌ Failed to load model: {e}")
    raise RuntimeError(f"Model load failed: {e}") from e
# ---------- UTILS ----------
def decode_mask_to_overlay(image_bgr, mask, class_colors=None, alpha=0.5):
    """Blend a per-pixel class mask onto an image as a coloured overlay.

    Args:
        image_bgr: image array of shape ``(H, W, C)`` whose last axis matches
            the length of each colour (in this file a 3-channel BGR uint8 image
            from OpenCV).
        mask: integer array of class ids with shape ``(H, W)``, aligned with
            ``image_bgr``.
        class_colors: mapping of class id -> colour tuple in the image's
            channel order. Defaults to the module-level ``CLASS_COLORS``.
            Pixels whose class id is not a key are left unchanged.
        alpha: weight of the class colour in the blend, expected in [0, 1].
            0 keeps the original pixel and 1 paints the solid colour. The
            default 0.5 is an equal mix.

    Returns:
        A new uint8 array. ``image_bgr`` itself is not modified.
    """
    if class_colors is None:
        class_colors = CLASS_COLORS
    overlay = image_bgr.copy()
    for class_id, color in class_colors.items():
        # Build the boolean selector once per class instead of twice.
        selected = mask == class_id
        blended = overlay[selected] * (1.0 - alpha) + np.array(color) * alpha
        # astype truncates fractional values toward zero (e.g. 177.5 -> 177).
        overlay[selected] = blended.astype(np.uint8)
    return overlay
def image_to_base64(img: np.ndarray) -> str:
    """Encode ``img`` as PNG and return the bytes as a base64 ASCII string.

    Raises:
        ValueError: if OpenCV reports that PNG encoding failed.
    """
    ok, buffer = cv2.imencode('.png', img)
    if not ok:
        # Check the success flag so a failed encode can never be returned to
        # the client as an empty or garbage payload.
        raise ValueError("Failed to encode image as PNG")
    return base64.b64encode(buffer).decode("utf-8")
# ---------- PREDICTION ROUTE ----------
@app.post("/predict_severity")
async def predict_severity(
    file: UploadFile = File(...),
    x_api_key: str = Depends(verify_api_key)
):
    """Segment an uploaded image and report disease severity.

    The image is resized to ``IMG_SIZE`` x ``IMG_SIZE``, scaled to [0, 1] and
    run through the segmentation model, and each pixel gets its arg-max class.
    Class 1 pixels count as healthy and class 2 as diseased. Other class ids
    (presumably background) do not affect severity.

    Returns:
        A dict (serialized to JSON) with these keys:

        - ``severity``: diseased / (healthy + diseased) * 100, rounded to 2 dp,
          or 0.0 when neither class is present.
        - ``class_counts``: pixel count for each class id present.
        - ``segmentation_mask_base64``: base64 PNG of the class overlay
          blended onto the resized image.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image; 500 for any other failure during processing.
    """
    try:
        contents = await file.read()
        img_bgr = None
        if contents:
            # Only decode non-empty data: cv2.imdecode may raise (rather than
            # return None) on an empty buffer.
            file_bytes = np.frombuffer(contents, np.uint8)
            img_bgr = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
        if img_bgr is None:
            # An unreadable upload is a client error, not a server fault.
            raise HTTPException(status_code=400, detail="Invalid image file")

        # NOTE(review): the model receives OpenCV's BGR channel order — confirm
        # this matches the channel order used in training.
        img_resized = cv2.resize(img_bgr, (IMG_SIZE, IMG_SIZE))
        img_norm = img_resized.astype(np.float32) / 255.0
        img_input = np.expand_dims(img_norm, axis=0)  # add batch dimension

        # NOTE(review): model.predict is synchronous and runs inside an async
        # handler, so it blocks the event loop for the whole inference.
        # Consider offloading it to a worker thread if concurrency matters.
        prediction = model.predict(img_input, verbose=0)[0]
        mask = np.argmax(prediction, axis=-1).astype(np.uint8)

        unique, counts = np.unique(mask, return_counts=True)
        class_counts = {int(k): int(v) for k, v in zip(unique, counts)}
        healthy = class_counts.get(1, 0)
        diseased = class_counts.get(2, 0)
        total = healthy + diseased
        severity_percent = (diseased / total) * 100 if total > 0 else 0.0

        overlay = decode_mask_to_overlay(img_resized, mask)
        mask_base64 = image_to_base64(overlay)
        return {
            "severity": round(severity_percent, 2),
            "class_counts": class_counts,
            "segmentation_mask_base64": mask_base64
        }
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) pass through unchanged
        # instead of being rewrapped as 500s.
        raise
    except Exception as e:
        logger.exception(f"Error during prediction: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/")
def read_root():
    """Health-check endpoint confirming the server is up."""
    return {"status": "Server is running βœ…"}


if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "app:app" import string. When this
    # file runs as __main__, the string form makes uvicorn import it again as
    # module "app", which repeats the model download and load. The guard sits
    # at the end of the file so every route is registered before serving.
    uvicorn.run(app, host="0.0.0.0", port=8000)