File size: 3,563 Bytes
73669d1
 
 
 
 
 
5d2bfa6
 
 
73669d1
5d2bfa6
fd44499
5d2bfa6
 
73669d1
5d2bfa6
 
73669d1
5d2bfa6
73669d1
 
 
 
5d2bfa6
73669d1
 
 
 
 
5d2bfa6
 
 
d8b04fd
5d2bfa6
 
 
73669d1
5d2bfa6
 
fd44499
 
5d2bfa6
73669d1
5d2bfa6
fd44499
de7cbe9
5d2bfa6
 
 
73669d1
5d2bfa6
73669d1
 
 
 
 
 
 
 
 
 
 
 
5d2bfa6
73669d1
 
 
5d2bfa6
73669d1
 
 
 
 
5d2bfa6
 
 
 
73669d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d2bfa6
 
afd5474
 
 
3e8dbea
 
 
afd5474
fd44499
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from fastapi import FastAPI, File, UploadFile, HTTPException, Header, Depends
from fastapi.middleware.cors import CORSMiddleware
import numpy as np
import tensorflow as tf
import cv2
import base64
import os
import logging
from huggingface_hub import hf_hub_download

# ---------- CONFIG ----------
# Secret that clients must present to use the API. It is read from the
# API_KEY environment variable so the real value never has to be committed;
# the historical placeholder is kept as a fallback so existing setups still
# start up unchanged.
_PLACEHOLDER_API_KEY = "your-secret-api-key"
API_KEY = os.environ.get("API_KEY", _PLACEHOLDER_API_KEY)

# Side length, in pixels, of the square image handed to the model.
IMG_SIZE = 256

# Overlay colour per predicted class id. The tuples are blended into an image
# held in OpenCV's BGR order, so class 1 is drawn green and class 2 red.
# NOTE(review): class 0 (black) is presumably background - confirm.
CLASS_COLORS = {0: (0, 0, 0), 1: (0, 255, 0), 2: (0, 0, 255)}

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

if API_KEY == _PLACEHOLDER_API_KEY:
    # Make an unconfigured deployment visible in the logs.
    logger.warning("API_KEY is not set; using the placeholder API key.")

# ---------- API SETUP ----------
app = FastAPI()

# Cross-origin policy: any origin, method and header is accepted, and
# credentials are allowed.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard origins with allow_credentials=True - confirm this is intended.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

def verify_api_key(x_api_key: str = Header(...)):
    """Reject the request unless the supplied API key matches API_KEY.

    Intended to be used as a FastAPI dependency.

    Args:
        x_api_key: API key taken from the request headers (required).

    Raises:
        HTTPException: 403 when the supplied key does not match API_KEY.
    """
    import hmac  # local import so this block does not depend on a file-level import

    # Compare in constant time so response timing does not leak how much of
    # the key matched. Bytes are used because compare_digest rejects str
    # values containing non-ASCII characters.
    supplied = x_api_key.encode("utf-8")
    expected = API_KEY.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=403, detail="Invalid API Key")

# ---------- LOAD MODEL ----------
# Download the trained network from the Hugging Face Hub and load it once at
# import time; the module cannot serve requests without it, so any failure
# aborts start-up with a RuntimeError.
try:
    # Point the Hugging Face home at /tmp to prevent permission issues on
    # Spaces.
    # NOTE(review): huggingface_hub is imported before this assignment, so the
    # library may already have read HF_HOME; the explicit cache_dir below is
    # what directs this download - confirm.
    os.environ["HF_HOME"] = "/tmp/huggingface"

    model_path = hf_hub_download(
        repo_id="rishab1090/potato",
        filename="unet_tf.keras",           # use the updated filename
        cache_dir="/tmp/hf_cache"           # avoids filesystem write issues
    )

    model = tf.keras.models.load_model(model_path)
    logger.info("✅ Model loaded successfully from unet_tf.keras.")

except Exception as e:
    logger.error(f"❌ Failed to load model: {e}")
    # Chain the original exception so the root cause stays in the traceback.
    raise RuntimeError(f"Model load failed: {e}") from e

# ---------- UTILS ----------
def decode_mask_to_overlay(image_bgr, mask, class_colors=None, alpha=0.5):
    """Return a copy of *image_bgr* with each class region tinted by its colour.

    Every pixel whose mask value equals a class id in the palette is replaced
    by ``pixel * (1 - alpha) + colour * alpha``, truncated to uint8. Pixels
    whose mask value is not in the palette are left unchanged. The input
    image is not modified.

    Args:
        image_bgr: Image array to tint; indexed by a boolean array derived
            from *mask*, so the two must have compatible leading dimensions.
        mask: Array of class ids, one per pixel.
        class_colors: Optional mapping of class id to colour tuple. Defaults
            to the module-level CLASS_COLORS.
        alpha: Weight of the class colour in the blend, in [0, 1]. The
            default of 0.5 gives an even mix.

    Returns:
        The tinted copy of *image_bgr*.

    Raises:
        ValueError: If *alpha* is outside [0, 1].
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be between 0 and 1")

    palette = CLASS_COLORS if class_colors is None else class_colors
    overlay = image_bgr.copy()
    for class_id, color in palette.items():
        # Build the boolean selection once and reuse it for read and write.
        selected = mask == class_id
        if not selected.any():
            continue  # nothing to tint for this class
        blended = overlay[selected] * (1.0 - alpha) + np.array(color) * alpha
        overlay[selected] = blended.astype(np.uint8)
    return overlay

def image_to_base64(img: np.ndarray) -> str:
    """Encode *img* as PNG and return the bytes as a base64 text string.

    Args:
        img: Image array to encode with OpenCV.

    Returns:
        The PNG-encoded image, base64-encoded and decoded to a UTF-8 str.

    Raises:
        ValueError: If OpenCV reports that the image could not be encoded.
    """
    encoded_ok, buffer = cv2.imencode('.png', img)
    # imencode reports failure through its first return value; do not
    # base64-encode a buffer that was never filled.
    if not encoded_ok:
        raise ValueError("Failed to encode image as PNG")
    return base64.b64encode(buffer).decode("utf-8")

# ---------- PREDICTION ROUTE ----------
@app.post("/predict_severity")
async def predict_severity(
    file: UploadFile = File(...),
    x_api_key: str = Depends(verify_api_key)
):
    """Segment an uploaded image and report the diseased share of the leaf.

    The upload is decoded with OpenCV, resized to IMG_SIZE x IMG_SIZE, scaled
    to the [0, 1] range and passed to the module-level model. The argmax over
    the last axis of the prediction is used as the per-pixel class mask.
    Severity is the count of class-2 (diseased) pixels as a percentage of
    class-1 (healthy) plus class-2 pixels, or 0.0 when neither is present.

    Args:
        file: The uploaded image file.
        x_api_key: Result of the verify_api_key dependency. It is not used in
            the body; declaring it makes the key check run for this route.

    Returns:
        A dict with "severity" (percentage rounded to 2 decimals),
        "class_counts" (pixel count per class id found in the mask) and
        "segmentation_mask_base64" (base64 PNG of the resized image with the
        class colours blended in).

    Raises:
        HTTPException: 400 when the upload is empty or cannot be decoded as
            an image; 500 for any other failure during prediction.
    """
    try:
        contents = await file.read()
        # An empty upload can never be decoded; report it as a client error
        # instead of letting the decoder fail.
        if not contents:
            raise HTTPException(status_code=400, detail="Invalid image file")

        file_bytes = np.frombuffer(contents, np.uint8)
        img_bgr = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)

        # imdecode returns None for data it cannot decode: the client's fault.
        if img_bgr is None:
            raise HTTPException(status_code=400, detail="Invalid image file")

        img_resized = cv2.resize(img_bgr, (IMG_SIZE, IMG_SIZE))
        img_norm = img_resized.astype(np.float32) / 255.0
        # Add the leading batch axis the model call expects.
        # NOTE(review): the image is fed in OpenCV's BGR channel order -
        # confirm the model was trained on BGR input.
        img_input = np.expand_dims(img_norm, axis=0)

        # NOTE(review): predict() is a blocking call inside an async handler.
        prediction = model.predict(img_input)[0]
        mask = np.argmax(prediction, axis=-1).astype(np.uint8)

        unique, counts = np.unique(mask, return_counts=True)
        class_counts = {int(k): int(v) for k, v in zip(unique, counts)}

        healthy = class_counts.get(1, 0)
        diseased = class_counts.get(2, 0)
        leaf_pixels = healthy + diseased
        # Guard against division by zero when no leaf pixels were found.
        severity_percent = (diseased / leaf_pixels) * 100 if leaf_pixels > 0 else 0.0

        overlay = decode_mask_to_overlay(img_resized, mask)
        mask_base64 = image_to_base64(overlay)

        return {
            "severity": round(severity_percent, 2),
            "class_counts": class_counts,
            "segmentation_mask_base64": mask_base64
        }

    except HTTPException:
        # Deliberate HTTP errors (such as the 400s above) must not be
        # rewritten into a 500 by the generic handler below.
        raise
    except Exception as e:
        logger.exception("Error during prediction: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/")
def read_root():
    """Health-check endpoint: report that the server is up."""
    return {"status": "Server is running ✅"}


# The entry-point guard is kept last so every route above is already
# registered before the blocking uvicorn.run call starts serving.
# NOTE(review): the "app:app" import string assumes this file is importable
# as module "app" - confirm the filename.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=8000)