rishab1090 commited on
Commit
5d2bfa6
Β·
verified Β·
1 Parent(s): de7cbe9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -28
app.py CHANGED
@@ -1,48 +1,54 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException, Header, Depends
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from fastapi.responses import JSONResponse
4
  import numpy as np
5
  import tensorflow as tf
6
  import cv2
7
  import base64
 
 
 
8
 
9
- # ---- Add this for API key protection ----
10
- API_KEY = "your-secret-api-key" # πŸ” Replace with your actual key
 
 
11
 
12
- def verify_api_key(x_api_key: str = Header(...)):
13
- if x_api_key != API_KEY:
14
- raise HTTPException(status_code=403, detail="Invalid API Key")
15
- # -----------------------------------------
16
 
 
17
  app = FastAPI()
18
 
19
  app.add_middleware(
20
  CORSMiddleware,
21
- allow_origins=["*"], # In production: allow only trusted domains
22
  allow_credentials=True,
23
  allow_methods=["*"],
24
  allow_headers=["*"],
25
  )
26
- from tensorflow.keras.models import load_model
27
 
28
- # Available backend options are: "jax", "torch", "tensorflow".
29
- import os
30
- os.environ["HF_HOME"] = "/tmp/huggingface" # βœ… Place this before hf_hub_download
31
- from huggingface_hub import hf_hub_download
32
 
33
- # Make sure HF_HOME is set before using huggingface_hub
34
- os.environ["HF_HOME"] = "/tmp/huggingface"
 
35
 
36
- model_path = hf_hub_download(
37
- repo_id="rishab1090/potato",
38
- filename="unet_model.keras"
39
- )
40
- model = tf.keras.models.load_model(model_path)
41
 
 
 
42
 
43
- IMG_SIZE = 256
44
- CLASS_COLORS = {0: (0, 0, 0), 1: (0, 255, 0), 2: (0, 0, 255)}
 
45
 
 
46
  def decode_mask_to_overlay(image_bgr, mask):
47
  overlay = image_bgr.copy()
48
  for class_id, color in CLASS_COLORS.items():
@@ -55,16 +61,20 @@ def image_to_base64(img: np.ndarray) -> str:
55
  _, buffer = cv2.imencode('.png', img)
56
  return base64.b64encode(buffer).decode("utf-8")
57
 
58
-
59
  @app.post("/predict_severity")
60
  async def predict_severity(
61
  file: UploadFile = File(...),
62
- x_api_key: str = Depends(verify_api_key) # πŸ” Require API key
63
  ):
64
  try:
65
  contents = await file.read()
66
  file_bytes = np.frombuffer(contents, np.uint8)
67
  img_bgr = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
 
 
 
 
68
  img_resized = cv2.resize(img_bgr, (IMG_SIZE, IMG_SIZE))
69
  img_norm = img_resized.astype(np.float32) / 255.0
70
  img_input = np.expand_dims(img_norm, axis=0)
@@ -72,9 +82,6 @@ async def predict_severity(
72
  prediction = model.predict(img_input)[0]
73
  mask = np.argmax(prediction, axis=-1).astype(np.uint8)
74
 
75
- center_pixel = prediction[IMG_SIZE // 2, IMG_SIZE // 2]
76
- print(f"Center pixel confidence: {center_pixel}")
77
-
78
  unique, counts = np.unique(mask, return_counts=True)
79
  class_counts = {int(k): int(v) for k, v in zip(unique, counts)}
80
 
@@ -92,4 +99,5 @@ async def predict_severity(
92
  }
93
 
94
  except Exception as e:
95
- raise HTTPException(status_code=500, detail=str(e))
 
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException, Header, Depends
2
  from fastapi.middleware.cors import CORSMiddleware
 
3
  import numpy as np
4
  import tensorflow as tf
5
  import cv2
6
  import base64
7
+ import os
8
+ import logging
9
+ from huggingface_hub import hf_hub_download
10
 
11
+ # ---------- CONFIG ----------
12
+ API_KEY = "your-secret-api-key" # Replace this with your key
13
+ IMG_SIZE = 256
14
+ CLASS_COLORS = {0: (0, 0, 0), 1: (0, 255, 0), 2: (0, 0, 255)}
15
 
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
 
 
18
 
19
+ # ---------- API SETUP ----------
20
  app = FastAPI()
21
 
22
  app.add_middleware(
23
  CORSMiddleware,
24
+ allow_origins=["*"],
25
  allow_credentials=True,
26
  allow_methods=["*"],
27
  allow_headers=["*"],
28
  )
 
29
 
30
+ def verify_api_key(x_api_key: str = Header(...)):
31
+ if x_api_key != API_KEY:
32
+ raise HTTPException(status_code=403, detail="Invalid API Key")
 
33
 
34
+ # ---------- LOAD MODEL ----------
35
+ try:
36
+ os.environ["HF_HOME"] = "/tmp/huggingface" # Prevent permission issues on Spaces
37
 
38
+ model_path = hf_hub_download(
39
+ repo_id="rishab1090/potato",
40
+ filename="unet_model.keras", # πŸ‘ˆ Use the exact filename
41
+ cache_dir="/tmp/hf_cache" # πŸ‘ˆ Helps avoid read-only FS errors
42
+ )
43
 
44
+ model = tf.keras.models.load_model(model_path)
45
+ logger.info("βœ… Model loaded successfully from .keras file.")
46
 
47
+ except Exception as e:
48
+ logger.error(f"❌ Failed to load model: {e}")
49
+ raise RuntimeError(f"Model load failed: {e}")
50
 
51
+ # ---------- UTILS ----------
52
  def decode_mask_to_overlay(image_bgr, mask):
53
  overlay = image_bgr.copy()
54
  for class_id, color in CLASS_COLORS.items():
 
61
  _, buffer = cv2.imencode('.png', img)
62
  return base64.b64encode(buffer).decode("utf-8")
63
 
64
+ # ---------- PREDICTION ROUTE ----------
65
  @app.post("/predict_severity")
66
  async def predict_severity(
67
  file: UploadFile = File(...),
68
+ x_api_key: str = Depends(verify_api_key)
69
  ):
70
  try:
71
  contents = await file.read()
72
  file_bytes = np.frombuffer(contents, np.uint8)
73
  img_bgr = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
74
+
75
+ if img_bgr is None:
76
+ raise ValueError("Invalid image file")
77
+
78
  img_resized = cv2.resize(img_bgr, (IMG_SIZE, IMG_SIZE))
79
  img_norm = img_resized.astype(np.float32) / 255.0
80
  img_input = np.expand_dims(img_norm, axis=0)
 
82
  prediction = model.predict(img_input)[0]
83
  mask = np.argmax(prediction, axis=-1).astype(np.uint8)
84
 
 
 
 
85
  unique, counts = np.unique(mask, return_counts=True)
86
  class_counts = {int(k): int(v) for k, v in zip(unique, counts)}
87
 
 
99
  }
100
 
101
  except Exception as e:
102
+ logger.error(f"Error during prediction: {e}")
103
+ raise HTTPException(status_code=500, detail=str(e))