warresnaet committed on
Commit
b11554b
·
verified ·
1 Parent(s): 154ab55

Reuse local main.py + numpy 1.26.4

Browse files
Files changed (2) hide show
  1. main.py +82 -72
  2. requirements.txt +1 -1
main.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  import os
2
  import logging
3
  from typing import List
@@ -6,111 +10,117 @@ import numpy as np
6
  from PIL import Image
7
 
8
  import tensorflow as tf
9
- from fastapi import FastAPI, File, UploadFile, HTTPException
10
- from fastapi.middleware.cors import CORSMiddleware
11
- from fastapi.responses import JSONResponse
12
-
13
  from huggingface_hub import hf_hub_download
14
 
15
 
16
- # -----------------------
17
- # Config
18
- # -----------------------
19
  logging.basicConfig(level=logging.INFO)
20
- logger = logging.getLogger("hf-space-fastapi")
21
-
22
- HF_REPO_ID = os.environ.get("HF_REPO_ID", "warresnaet/masterclass-2025")
23
- HF_MODEL_FILENAME = os.environ.get("HF_MODEL_FILENAME", "model.keras")
24
- HF_REVISION = os.environ.get("HF_REVISION", "main")
25
- LOCAL_MODEL_DIR = os.environ.get("LOCAL_MODEL_DIR", "./model")
26
-
27
- ANIMALS: List[str] = ["Cat", "Dog", "Panda"]
28
 
29
 
30
- # -----------------------
31
- # App
32
- # -----------------------
33
- app = FastAPI(title="Animal Classification API (HF Space)")
34
 
35
  app.add_middleware(
36
  CORSMiddleware,
37
  allow_origins=["*"],
 
38
  allow_methods=["*"],
39
  allow_headers=["*"],
40
- allow_credentials=True,
41
  )
42
 
43
 
44
- def _ensure_model_local() -> str:
45
- """Download model file from HF if missing; return the absolute path to the model file."""
46
- os.makedirs(LOCAL_MODEL_DIR, exist_ok=True)
47
- local_model_path = os.path.join(LOCAL_MODEL_DIR, HF_MODEL_FILENAME)
48
- if os.path.exists(local_model_path):
49
- return os.path.abspath(local_model_path)
50
-
51
- logger.info(
52
- f"Downloading model from HF Hub: repo_id={HF_REPO_ID}, "
53
- f"filename={HF_MODEL_FILENAME}, revision={HF_REVISION}"
54
- )
55
- downloaded = hf_hub_download(
56
- repo_id=HF_REPO_ID,
57
- filename=HF_MODEL_FILENAME,
58
- repo_type="model",
59
- revision=HF_REVISION,
60
- local_dir=LOCAL_MODEL_DIR,
61
- )
62
- return os.path.abspath(downloaded)
63
-
64
-
65
- def _load_model() -> tf.keras.Model:
66
- """Load the Keras model from the local path."""
67
- model_path = _ensure_model_local()
68
- logger.info(f"Loading Keras model from: {model_path}")
69
- model = tf.keras.models.load_model(model_path)
70
- logger.info("Model loaded")
71
- return model
72
 
73
 
74
- MODEL: tf.keras.Model = _load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
 
77
  @app.get("/")
78
- def root():
79
- return {"ok": True, "classes": ANIMALS}
80
-
81
-
82
- @app.get("/health")
83
- def health():
84
- return {"status": "healthy", "model_loaded": MODEL is not None}
85
 
86
 
87
  @app.post("/upload/image")
88
  async def upload_image(img: UploadFile = File(...)):
 
 
 
89
  """
90
- Accept an image, resize to (64, 64), run prediction, and return label + scores.
91
- """
92
- if MODEL is None:
93
- raise HTTPException(status_code=503, detail="Model is not loaded")
94
 
95
  try:
96
- image = Image.open(img.file).convert("RGB").resize((64, 64))
97
- arr = np.array(image, dtype=np.float32)
98
- batch = np.expand_dims(arr, axis=0)
99
-
100
- probs = MODEL.predict(batch, verbose=0)[0]
101
- probs = np.asarray(probs, dtype=np.float32)
102
- idx = int(np.argmax(probs))
103
- label = ANIMALS[idx] if 0 <= idx < len(ANIMALS) else str(idx)
 
 
 
 
 
 
 
 
 
 
104
 
105
  return JSONResponse({"label": label, "scores": probs.tolist()})
106
- except Exception as e:
 
107
  logger.exception("Failed to process image")
108
- raise HTTPException(status_code=400, detail=f"Bad image: {e}")
109
 
110
 
111
  if __name__ == "__main__":
112
- # For local testing. In HF Spaces (Docker SDK), the container will run uvicorn with this app.
 
113
  import uvicorn
114
 
115
- port = int(os.environ.get("PORT", "7860"))
116
  uvicorn.run("main:app", host="0.0.0.0", port=port, reload=False)
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import JSONResponse
4
+
5
  import os
6
  import logging
7
  from typing import List
 
10
  from PIL import Image
11
 
12
  import tensorflow as tf
 
 
 
 
13
  from huggingface_hub import hf_hub_download
14
 
15
 
 
 
 
16
  logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
18
 
19
 
20
+ app = FastAPI(title="Animal Classification API")
 
 
 
21
 
22
  app.add_middleware(
23
  CORSMiddleware,
24
  allow_origins=["*"],
25
+ allow_credentials=True,
26
  allow_methods=["*"],
27
  allow_headers=["*"],
 
28
  )
29
 
30
 
31
+ ANIMALS: List[str] = ["Cat", "Dog", "Panda"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
 
34
+ # Resolve model path. Prefer an environment variable for flexibility. As a fallback
35
+ # try a model.keras file in the repository root (one level up from this file).
36
+ base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
37
+ model_path = os.environ.get("MODEL_PATH") or os.path.join(base_dir, "model.keras")
38
+
39
+
40
+ model = None
41
+ try:
42
+ if os.path.exists(model_path):
43
+ logger.info(f"Loading model from: {model_path}")
44
+ model = tf.keras.models.load_model(model_path)
45
+ logger.info("Model loaded successfully")
46
+ else:
47
+ # Try Hugging Face Hub fallback when local model is missing
48
+ repo_id = os.environ.get("HF_REPO_ID")
49
+ filename = os.environ.get("HF_MODEL_FILENAME", "model.keras")
50
+ revision = os.environ.get("HF_REVISION")
51
+ if repo_id:
52
+ try:
53
+ logger.info(
54
+ f"Local model not found. Trying HF Hub: repo_id={repo_id}, filename={filename}, revision={revision}"
55
+ )
56
+ local_dir = os.path.join(base_dir, "hf_model")
57
+ os.makedirs(local_dir, exist_ok=True)
58
+ downloaded_path = hf_hub_download(
59
+ repo_id=repo_id,
60
+ filename=filename,
61
+ repo_type="model",
62
+ revision=revision,
63
+ local_dir=local_dir,
64
+ )
65
+ logger.info(f"Downloaded model file to: {downloaded_path}")
66
+ model = tf.keras.models.load_model(downloaded_path)
67
+ logger.info("Model loaded successfully from HF Hub")
68
+ except Exception:
69
+ logger.exception("HF Hub fallback failed")
70
+ if model is None:
71
+ logger.warning(
72
+ "Model not available. Set MODEL_PATH or HF_REPO_ID (+ HF_MODEL_FILENAME) environment variables."
73
+ )
74
+ except Exception as e:
75
+ logger.exception("Failed to load model")
76
+ model = None
77
 
78
 
79
  @app.get("/")
80
+ def read_root():
81
+ return {"hello": "world"}
 
 
 
 
 
82
 
83
 
84
  @app.post("/upload/image")
85
  async def upload_image(img: UploadFile = File(...)):
86
+ """Accept an uploaded image, resize to (64,64), run model.predict and return the label.
87
+
88
+ If the model is not available the endpoint will return 503.
89
  """
90
+ if model is None:
91
+ raise HTTPException(status_code=503, detail="Model is not loaded on the server")
 
 
92
 
93
  try:
94
+ # Read image bytes and ensure RGB
95
+ original_image = Image.open(img.file).convert("RGB")
96
+ # Preprocess the image
97
+ original_image = original_image.resize((64, 64))
98
+ # Training used raw pixel values [0-255], NOT normalized to [0-1]
99
+ img_array = np.array(original_image, dtype=np.float32)
100
+ img_array = np.expand_dims(img_array, axis=0)
101
+
102
+ predictions = model.predict(img_array)
103
+ # predictions might be shape (1, N)
104
+ probs = np.asarray(predictions).squeeze()
105
+ if probs.ndim == 0:
106
+ # Model returned a single value
107
+ label_idx = int(np.round(probs))
108
+ else:
109
+ label_idx = int(np.argmax(probs))
110
+
111
+ label = ANIMALS[label_idx] if 0 <= label_idx < len(ANIMALS) else str(label_idx)
112
 
113
  return JSONResponse({"label": label, "scores": probs.tolist()})
114
+
115
+ except Exception:
116
  logger.exception("Failed to process image")
117
+ raise HTTPException(status_code=400, detail="Failed to process image")
118
 
119
 
120
  if __name__ == "__main__":
121
+ # Run with: python main.py
122
+ # Use Uvicorn as the ASGI server. MODEL_PATH and PORT can be overridden via env vars.
123
  import uvicorn
124
 
125
+ port = int(os.environ.get("PORT", 8000))
126
  uvicorn.run("main:app", host="0.0.0.0", port=port, reload=False)
requirements.txt CHANGED
@@ -1,7 +1,7 @@
1
  fastapi==0.116.1
2
  uvicorn[standard]==0.23.2
3
  Pillow==10.1.0
4
- numpy<2.0.0
5
  tensorflow-cpu==2.16.1
6
  huggingface_hub>=0.20.0
7
  python-multipart
 
1
  fastapi==0.116.1
2
  uvicorn[standard]==0.23.2
3
  Pillow==10.1.0
4
+ numpy==1.26.4
5
  tensorflow-cpu==2.16.1
6
  huggingface_hub>=0.20.0
7
  python-multipart