JackRabbit committed on
Commit
c442c56
·
1 Parent(s): e9471fb

new fastapi app

Browse files
Files changed (2) hide show
  1. app.py +202 -103
  2. requirements.txt +5 -4
app.py CHANGED
@@ -1,143 +1,242 @@
1
- import gradio as gr
2
- import pandas as pd
3
- import requests
4
- from PIL import Image
5
  import io
 
 
 
6
  import os
7
-
8
- from huggingface_hub import snapshot_download
 
9
  from autogluon.multimodal import MultiModalPredictor
 
10
 
 
 
 
 
 
 
 
 
 
11
 
12
- ###########################
13
- # Load Model from HF Hub
14
- ###########################
15
  def load_model():
16
- """Download your model files from the Hub and load with AutoGluon."""
 
 
 
17
  repo_id = "Honey-Bee-Society/honeybee_ml_v1"
18
  local_dir = snapshot_download(repo_id)
 
 
 
 
 
 
 
19
  predictor = MultiModalPredictor.load(local_dir)
20
  return predictor
21
 
22
- predictor = load_model() # Load once at startup
23
-
24
-
25
- ###########################
26
- # Utility Functions
27
- ###########################
28
- def resize_image_if_large(image, max_size_mb=1):
29
- """Resizes the image if it is larger than `max_size_mb` MB."""
30
- img_bytes = io.BytesIO()
31
- image.save(img_bytes, format='PNG')
32
- size_mb = len(img_bytes.getvalue()) / (1024 * 1024)
33
- if size_mb > max_size_mb:
34
- scale_factor = (max_size_mb / size_mb) ** 0.5
35
- new_w = int(image.width * scale_factor)
36
- new_h = int(image.height * scale_factor)
37
- image = image.resize((new_w, new_h))
38
- return image
39
-
40
- def classify_image(image: Image.Image):
41
  """
42
- Given a PIL Image, predict Honey Bee vs. Bumblebee vs. Vespidae.
43
- Returns a dictionary of probabilities + predicted label.
44
  """
45
- # Optionally resize large images
46
- image = resize_image_if_large(image, 1)
 
47
 
48
- # Convert to bytes for the predictor
49
- buf = io.BytesIO()
50
- image.save(buf, format='PNG')
51
- df = pd.DataFrame({"image": [buf.getvalue()]})
 
 
 
52
 
 
 
 
 
 
 
 
 
 
 
53
  probabilities = predictor.predict_proba(df, realtime=True)
 
54
 
55
- # Adjust indices if your model’s labeling is different
 
 
 
 
 
56
  honeybee_score = float(probabilities[1].iloc[0]) * 100
57
  bumblebee_score = float(probabilities[2].iloc[0]) * 100
58
  vespidae_score = float(probabilities[3].iloc[0]) * 100
59
 
60
  highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
61
  if highest_score < 80:
62
- label = "No bee detected (scores too low)."
63
  else:
64
  if honeybee_score == highest_score:
65
- label = "Honey Bee"
66
  elif bumblebee_score == highest_score:
67
- label = "Bumblebee"
68
  else:
69
- label = "Vespidae (wasp/hornet)"
70
 
71
  return {
72
  "honeybee_score": honeybee_score,
73
  "bumblebee_score": bumblebee_score,
74
  "vespidae_score": vespidae_score,
75
- "prediction_label": label
76
  }
77
 
78
 
79
- ###########################
80
- # The Main Predict Function
81
- ###########################
82
- def predict(uploaded_image_dict, fallback_url):
83
  """
84
- 1) If `uploaded_image_dict["path"]` looks like a URL (http/https),
85
- fetch the image from the internet.
86
- 2) Otherwise, Gradio will have downloaded it to a local file path
87
- and `uploaded_image_dict["name"]` is that path. (or "path" is local)
88
- 3) If the user provides nothing in the first input, try `fallback_url`.
89
  """
90
- # Case 1: We have an "uploaded_image_dict" from the first input
91
- if uploaded_image_dict and "path" in uploaded_image_dict:
92
- path_val = uploaded_image_dict["path"]
93
- if path_val.startswith("http"):
94
- # It's a remote URL
95
- try:
96
- resp = requests.get(path_val.strip(), timeout=10)
97
- resp.raise_for_status()
98
- image = Image.open(io.BytesIO(resp.content))
99
- except Exception as e:
100
- return {"error": f"Failed to download/open image from URL: {e}"}
101
- else:
102
- # It's likely a local file path Gradio prepared
103
- try:
104
- image = Image.open(path_val)
105
- except Exception as e:
106
- return {"error": f"Failed to open local file: {e}"}
107
- # Case 2: No image dict, but maybe we have a fallback URL
108
- elif fallback_url.strip():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  try:
110
- resp = requests.get(fallback_url.strip(), timeout=10)
111
- resp.raise_for_status()
112
- image = Image.open(io.BytesIO(resp.content))
113
  except Exception as e:
114
- return {"error": f"Failed to download/open fallback URL: {e}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  else:
116
- return {"error": "No image URL or fallback URL provided."}
117
-
118
- # Classify the resulting image
119
- return classify_image(image)
120
-
121
-
122
- ###########################
123
- # Gradio Interface
124
- ###########################
125
- # 1) First input = an Image (filepath)
126
- # 2) Second input = a Textbox
127
- # This order means data[0] is the dictionary for the image,
128
- # data[1] is the string for fallback_url.
129
-
130
- demo = gr.Interface(
131
- fn=predict,
132
- inputs=[
133
- gr.Image(type="filepath", label="Public Image or Uploaded File"),
134
- gr.Textbox(label="Fallback URL (optional)")
135
- ],
136
- outputs="json",
137
- api_name="predict",
138
- title="Honey Bee Image Classification",
139
- description="Provide an image dict with 'path' or a fallback URL. We'll classify the bee!"
140
- )
141
-
142
- if __name__ == "__main__":
143
- demo.launch()
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from pydantic import BaseModel
3
+ import uvicorn
 
4
  import io
5
+ import logging
6
+ import datetime
7
+ import re
8
  import os
9
+ import requests
10
+ import pandas as pd
11
+ from PIL import Image
12
  from autogluon.multimodal import MultiModalPredictor
13
+ from huggingface_hub import snapshot_download
14
 
15
+ ###############################################################################
16
+ # Logging configuration (optional)
17
+ ###############################################################################
18
+ log_filename = "model_predictions.log"
19
+ logging.basicConfig(
20
+ filename=log_filename,
21
+ level=logging.INFO,
22
+ format='%(asctime)s - %(message)s'
23
+ )
24
 
25
###############################################################################
# Model loading
###############################################################################
def load_model():
    """
    Fetch the model snapshot from the Hugging Face Hub and load it with
    AutoGluon's MultiModalPredictor.

    Returns
    -------
    MultiModalPredictor
        The predictor loaded from the downloaded snapshot directory.

    Raises
    ------
    FileNotFoundError
        If the snapshot is missing the files the predictor needs.
    """
    repo_id = "Honey-Bee-Society/honeybee_ml_v1"
    model_dir = snapshot_download(repo_id)

    # Sanity-check the snapshot contents before handing the directory to
    # AutoGluon, so a partial download fails fast with a clear error.
    required_files = ("assets.json", "model.ckpt")
    if not all(os.path.exists(os.path.join(model_dir, name)) for name in required_files):
        raise FileNotFoundError("Required model files not found in the downloaded directory.")

    return MultiModalPredictor.load(model_dir)
44
 
45
###############################################################################
# Image processing and prediction routines
###############################################################################
def resize_image_proportionally(image, max_size_mb=1):
    """
    Downscale *image* proportionally when its PNG-encoded size exceeds
    ``max_size_mb`` megabytes.

    Parameters
    ----------
    image : PIL.Image.Image
        The image to (potentially) shrink.
    max_size_mb : float, optional
        Size budget in megabytes. Defaults to 1.

    Returns
    -------
    PIL.Image.Image
        The original image when it already fits the budget, otherwise a
        proportionally resized copy.
    """
    # Measure the in-memory PNG size, which is what ultimately gets fed
    # to the predictor.
    img_byte_array = io.BytesIO()
    image.save(img_byte_array, format='PNG')
    img_size = len(img_byte_array.getvalue()) / (1024 * 1024)

    if img_size > max_size_mb:
        # Area scales with the square of the linear factor, hence sqrt.
        scale_factor = (max_size_mb / img_size) ** 0.5
        # Clamp to at least 1px per side: extreme aspect ratios could
        # otherwise round a dimension down to 0, which Image.resize rejects.
        new_width = max(1, int(image.width * scale_factor))
        new_height = max(1, int(image.height * scale_factor))
        image = image.resize((new_width, new_height))

    return image
64
 
65
+
66
def predict_image(image: Image.Image, predictor: MultiModalPredictor):
    """
    Serialize *image* to PNG bytes and run it through *predictor*.

    Returns the per-class probability DataFrame produced by
    ``predictor.predict_proba``.
    """
    # The predictor consumes a DataFrame whose "image" column holds raw
    # encoded image bytes.
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    frame = pd.DataFrame({"image": [buffer.getvalue()]})
    return predictor.predict_proba(frame, realtime=True)
77
 
78
+
79
def determine_label(probabilities, threshold=80.0):
    """
    Convert a probability DataFrame into percentage scores and a label.

    Parameters
    ----------
    probabilities : pandas.DataFrame
        Single-row frame whose integer columns 1, 2 and 3 hold the
        honeybee, bumblebee and vespidae class probabilities (0..1).
    threshold : float, optional
        Minimum top score (in percent) required to report a bee class;
        below it the "no bee detected" label is returned. Defaults to 80.

    Returns
    -------
    dict
        The three scores (percent) plus a human-readable prediction label.
    """
    honeybee_score = float(probabilities[1].iloc[0]) * 100
    bumblebee_score = float(probabilities[2].iloc[0]) * 100
    vespidae_score = float(probabilities[3].iloc[0]) * 100

    highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
    # Flat guard-style chain; ties are resolved in favour of the class
    # checked first (honeybee, then bumblebee), as before.
    if highest_score < threshold:
        prediction_label = "No bee detected (scores too low)."
    elif honeybee_score == highest_score:
        prediction_label = "Honey Bee"
    elif bumblebee_score == highest_score:
        prediction_label = "Bumblebee"
    else:
        prediction_label = "Vespidae (wasp/hornet)"

    return {
        "honeybee_score": honeybee_score,
        "bumblebee_score": bumblebee_score,
        "vespidae_score": vespidae_score,
        "prediction_label": prediction_label
    }
105
 
106
 
107
def log_predictions(honeybee_score, bumblebee_score, vespidae_score, source_info):
    """
    Record one prediction — the image source plus the three class
    scores — in the module's log file (optional best-effort logging).
    """
    message = (
        f"Source: {source_info}, "
        f"Honeybee: {honeybee_score:.2f}%, "
        f"Bumblebee: {bumblebee_score:.2f}%, "
        f"Vespidae: {vespidae_score:.2f}%"
    )
    logging.info(message)
117
+
118
###############################################################################
# Request models
###############################################################################
class ImageUrlRequest(BaseModel):
    """JSON request body for /predict when classifying an image by URL."""
    # URL the endpoint will fetch with requests.get; must be reachable
    # from the server.
    image_url: str
123
+
124
###############################################################################
# FastAPI app and endpoints
###############################################################################
app = FastAPI(title="Honey Bee Classification API")

# Load the model at startup (only once). NOTE(review): this runs at import
# time, so a failed Hub download makes the whole app fail to start.
predictor = load_model()

@app.get("/ping")
def ping():
    """
    Health-check endpoint: returns a static payload so callers can verify
    the API process is up.
    """
    return {"message": "pong"}
138
+
139
@app.post("/predict")
async def predict_endpoint(
    image_url_req: ImageUrlRequest = None,
    file: UploadFile = File(None)
):
    """
    Classify a bee image.

    Accepts either a JSON body with `image_url` or a multipart form-data
    `file`. Returns JSON with honeybee, bumblebee, vespidae scores, and a
    predicted label.

    Raises
    ------
    HTTPException
        400 for a missing/unretrievable/unreadable image, 413 when the
        image exceeds the 10MB limit, 500 when prediction itself fails.
    """
    # 1) If user provided an image URL
    if image_url_req and image_url_req.image_url:
        image_url = image_url_req.image_url
        source_name = image_url

        # Download the image. The timeout keeps one slow remote host from
        # tying up the worker indefinitely.
        try:
            response = requests.get(
                image_url,
                headers={"User-Agent": "HoneyBeeClassification/1.0 (+https://example.com)"},
                timeout=10,
            )
        except requests.RequestException as e:
            # Network-level failure only (DNS, connection, timeout, ...).
            # Catching Exception here would also swallow and re-wrap the
            # HTTPException raised below for non-200 statuses.
            raise HTTPException(
                status_code=400,
                detail=f"Error downloading image from {image_url}: {e}"
            )
        if response.status_code != 200:
            raise HTTPException(
                status_code=400,
                detail=f"Failed to retrieve image from {image_url}. HTTP {response.status_code}"
            )

        image_bytes = response.content
        image_size_mb = len(image_bytes) / (1024*1024)
        if image_size_mb > 10:
            raise HTTPException(
                status_code=413,
                detail=f"Image size {image_size_mb:.2f}MB exceeds 10MB limit."
            )
        # Convert to PIL Image
        try:
            image = Image.open(io.BytesIO(image_bytes))
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"Could not open image: {e}"
            )

    # 2) If user instead provided a file
    elif file is not None:
        source_name = f"uploaded_file:{file.filename}"

        # Determine the upload size by seeking within the spooled file.
        file.file.seek(0, 2)  # move to end
        file_size = file.file.tell()
        file.file.seek(0)  # reset pointer
        mb_size = file_size / (1024 * 1024)
        if mb_size > 10:
            raise HTTPException(
                status_code=413,
                detail=f"Uploaded file size {mb_size:.2f}MB exceeds 10MB limit."
            )

        # Convert to PIL Image
        try:
            contents = await file.read()
            image = Image.open(io.BytesIO(contents))
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"Could not open uploaded image: {e}"
            )
    else:
        raise HTTPException(
            status_code=400,
            detail="No image provided. Supply either `image_url` or `file`."
        )

    # Shrink oversized images so the predictor gets a bounded payload.
    image = resize_image_proportionally(image)

    # Predict
    try:
        probabilities = predict_image(image, predictor)
        results = determine_label(probabilities)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {e}"
        )

    # Optionally log predictions
    log_predictions(
        results["honeybee_score"],
        results["bumblebee_score"],
        results["vespidae_score"],
        source_info=source_name
    )

    return results
237
+
238
+
239
+ # If running locally, uncomment to start the server via `python app.py`
240
+ # (On Hugging Face Spaces, a separate command may be used.)
241
+ # if __name__ == "__main__":
242
+ # uvicorn.run(app, host="0.0.0.0", port=7860)
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
- Pillow
 
 
 
2
  pandas
3
  autogluon.multimodal
4
- huggingface_hub
5
- requests
6
- gradio
 
1
+ fastapi
2
+ uvicorn
3
+ pillow
4
+ requests
5
  pandas
6
  autogluon.multimodal
7
+ huggingface_hub