Spaces:
Sleeping
Sleeping
File size: 3,563 Bytes
73669d1 5d2bfa6 73669d1 5d2bfa6 fd44499 5d2bfa6 73669d1 5d2bfa6 73669d1 5d2bfa6 73669d1 5d2bfa6 73669d1 5d2bfa6 d8b04fd 5d2bfa6 73669d1 5d2bfa6 fd44499 5d2bfa6 73669d1 5d2bfa6 fd44499 de7cbe9 5d2bfa6 73669d1 5d2bfa6 73669d1 5d2bfa6 73669d1 5d2bfa6 73669d1 5d2bfa6 73669d1 5d2bfa6 afd5474 3e8dbea afd5474 fd44499 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 | from fastapi import FastAPI, File, UploadFile, HTTPException, Header, Depends
from fastapi.middleware.cors import CORSMiddleware
import numpy as np
import tensorflow as tf
import cv2
import base64
import os
import logging
from huggingface_hub import hf_hub_download
# ---------- CONFIG ----------
# The API key is read from the environment (e.g. a Space secret) so the real
# value never has to be committed to source control. The placeholder remains
# as a fallback so setups that relied on it keep working; replace it or set
# the API_KEY environment variable before deploying.
API_KEY = os.environ.get("API_KEY", "your-secret-api-key")
IMG_SIZE = 256  # side length (pixels) the upload is resized to before inference
# Overlay colour per predicted class id, in OpenCV's BGR channel order:
# 0 -> black, 1 (healthy) -> green, 2 (diseased) -> red.
CLASS_COLORS = {0: (0, 0, 0), 1: (0, 255, 0), 2: (0, 0, 255)}
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ---------- API SETUP ----------
app = FastAPI()

# Wide-open CORS policy: every origin, HTTP method and request header is
# accepted, and credentialed cross-origin requests are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; consider narrowing allow_origins before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
def verify_api_key(x_api_key: str = Header(...)):
    """Reject the request unless its ``x-api-key`` header matches ``API_KEY``.

    Used as a FastAPI dependency. ``Header(...)`` marks the header as
    required, so a request without it is rejected by FastAPI's own request
    validation before this body runs.

    Raises:
        HTTPException: 403 when the supplied key does not match.
    """
    import secrets  # local import keeps this fix self-contained

    # Constant-time comparison so response timing does not reveal how much of
    # the key was guessed correctly. Compare bytes: compare_digest raises
    # TypeError for non-ASCII ``str`` arguments.
    supplied = x_api_key.encode("utf-8")
    expected = API_KEY.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(status_code=403, detail="Invalid API Key")
# ---------- LOAD MODEL ----------
# Download the U-Net from the Hugging Face Hub and load it once at import
# time; every request then reuses this module-level ``model``.
try:
    # NOTE(review): huggingface_hub most likely reads HF_HOME when it is
    # imported (at the top of this file), so setting it here probably has no
    # effect on hf_hub_download. The explicit cache_dir below is what keeps
    # the download in writable /tmp.
    os.environ["HF_HOME"] = "/tmp/huggingface"  # Prevent permission issues on Spaces
    model_path = hf_hub_download(
        repo_id="rishab1090/potato",
        filename="unet_tf.keras",  # use the updated filename
        cache_dir="/tmp/hf_cache",  # avoids filesystem write issues
    )
    model = tf.keras.models.load_model(model_path)
    logger.info("✅ Model loaded successfully from unet_tf.keras.")
except Exception as e:
    # The service cannot work without the model: log the full traceback and
    # abort start-up, keeping the original error as the cause.
    logger.exception("❌ Failed to load model: %s", e)
    raise RuntimeError(f"Model load failed: {e}") from e
# ---------- UTILS ----------
def decode_mask_to_overlay(image_bgr, mask, class_colors=None, alpha=0.5):
    """Blend a per-class colour over *image_bgr* wherever *mask* has that class.

    Args:
        image_bgr: image array whose first two dimensions match *mask*; in
            this file it is the resized 3-channel uint8 BGR image from
            ``cv2.imdecode``. It is not modified.
        mask: 2-D array of integer class ids aligned with *image_bgr*.
        class_colors: mapping ``class_id -> colour tuple``; defaults to the
            module-level ``CLASS_COLORS``.
        alpha: weight of the class colour in the blend, expected in [0, 1].
            The default 0.5 is an equal mix of pixel and colour.

    Returns:
        A new uint8 array. Pixels whose class id is not a key of
        *class_colors* are copied unchanged.
    """
    if class_colors is None:
        class_colors = CLASS_COLORS
    overlay = image_bgr.copy()
    for class_id, color in class_colors.items():
        selected = mask == class_id  # build the boolean selector once per class
        blended = overlay[selected] * (1.0 - alpha) + np.asarray(color) * alpha
        # astype truncates the float blend toward zero (e.g. 177.5 -> 177).
        overlay[selected] = blended.astype(np.uint8)
    return overlay
def image_to_base64(img: np.ndarray) -> str:
    """Encode *img* as PNG and return the bytes as a base64 ASCII string.

    Raises:
        ValueError: if OpenCV reports that PNG encoding failed.
    """
    ok, buffer = cv2.imencode('.png', img)
    # cv2.imencode can signal failure through its boolean flag instead of
    # raising; never base64-encode a buffer from a failed encode.
    if not ok:
        raise ValueError("Failed to encode image as PNG")
    return base64.b64encode(buffer).decode("utf-8")
# ---------- PREDICTION ROUTE ----------
@app.post("/predict_severity")
async def predict_severity(
    file: UploadFile = File(...),
    x_api_key: str = Depends(verify_api_key),
):
    """Segment an uploaded image and report the diseased-area percentage.

    The upload is decoded with OpenCV, resized to ``IMG_SIZE`` x ``IMG_SIZE``,
    scaled to [0, 1] and run through the segmentation model. Severity is the
    share of class-2 (diseased) pixels among class-1 (healthy) plus class-2
    pixels; any other class id is ignored in that ratio.

    Returns:
        dict with ``severity`` (percentage rounded to 2 decimals; ``0.0`` when
        neither class 1 nor class 2 is predicted), ``class_counts`` (pixel
        count per predicted class id) and ``segmentation_mask_base64`` (the
        colour overlay as a base64-encoded PNG).

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image; 500 for any other failure during inference or encoding.
    """
    try:
        contents = await file.read()
        # Unusable uploads are client errors (400). The empty case is checked
        # up front because cv2.imdecode may raise on an empty buffer instead of
        # returning None.
        if not contents:
            raise HTTPException(status_code=400, detail="Invalid image file")
        file_bytes = np.frombuffer(contents, np.uint8)
        img_bgr = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
        if img_bgr is None:
            raise HTTPException(status_code=400, detail="Invalid image file")

        # NOTE(review): OpenCV decodes to BGR and the image is fed to the model
        # as-is; confirm the model was trained on BGR rather than RGB input.
        img_resized = cv2.resize(img_bgr, (IMG_SIZE, IMG_SIZE))
        img_norm = img_resized.astype(np.float32) / 255.0
        img_input = np.expand_dims(img_norm, axis=0)  # add the batch dimension

        # verbose=0 stops Keras from printing a progress bar on every request.
        prediction = model.predict(img_input, verbose=0)[0]
        # assumes prediction is (H, W, num_classes) per-pixel scores — confirm
        mask = np.argmax(prediction, axis=-1).astype(np.uint8)

        unique, counts = np.unique(mask, return_counts=True)
        class_counts = {int(k): int(v) for k, v in zip(unique, counts)}
        healthy = class_counts.get(1, 0)
        diseased = class_counts.get(2, 0)
        total = healthy + diseased
        severity_percent = diseased / total * 100 if total > 0 else 0.0

        overlay = decode_mask_to_overlay(img_resized, mask)
        mask_base64 = image_to_base64(overlay)
        return {
            "severity": round(severity_percent, 2),
            "class_counts": class_counts,
            "segmentation_mask_base64": mask_base64,
        }
    except HTTPException:
        # Already carries the intended status code; don't remap it to 500.
        raise
    except Exception as e:
        logger.exception("Error during prediction: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/")
def read_root():
    """Health-check endpoint: report that the server process is up."""
    return {"status": "Server is running ✅"}


if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "app:app" import string. With the
    # string, uvicorn imports this file a second time as module "app", which
    # re-runs the model download/load and only works if the file is named
    # app.py. The guard sits after every route so all of them are registered
    # before uvicorn.run() blocks.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|