Wills17 committed on
Commit
613c8a3
·
verified ·
1 Parent(s): e8f5584
Files changed (1) hide show
  1. FastAPI_app.py +114 -83
FastAPI_app.py CHANGED
@@ -3,72 +3,142 @@
3
  # import libraries
4
  import os
5
  import io
6
- import numpy as np
7
- import traceback
8
  import time
9
- import uvicorn
10
-
11
 
12
- # Heavy imports
13
- import tensorflow as tf
14
  from PIL import Image
15
  from fastapi import FastAPI, Form, UploadFile, File, Request, HTTPException
16
  from fastapi.responses import HTMLResponse
17
  from fastapi.staticfiles import StaticFiles
18
  from fastapi.templating import Jinja2Templates
19
  from fastapi.middleware.cors import CORSMiddleware
 
 
 
20
  import google.generativeai as genai
21
 
 
 
 
 
 
22
 
 
 
 
23
 
24
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
25
 
26
- model_name = "gpt2"
27
- save_path = "./models/gpt2"
28
 
29
- tokenizer = AutoTokenizer.from_pretrained(model_name)
30
- tokenizer.save_pretrained(save_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- model = AutoModelForCausalLM.from_pretrained(model_name)
33
- model.save_pretrained(save_path)
 
34
 
 
 
 
 
 
 
 
 
 
35
 
36
- # Load model (global) once startup.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  MODEL_PATH = "models/ingredient_model.h5"
38
  MODEL = tf.keras.models.load_model(MODEL_PATH)
39
 
40
  # Class names
41
- # CLASS_NAMES = sorted(os.listdir("dataset/dataset_2/train"))
42
- CLASS_NAMES = [
43
- 'apple', 'banana', 'beetroot', 'bell pepper', 'cabbage', 'capsicum', 'carrot', 'cauliflower', 'chilli pepper',
44
- 'corn', 'cucumber', 'eggplant', 'garlic', 'ginger', 'grapes', 'jalepeno', 'kiwi', 'lemon', 'lettuce', 'mango',
45
- 'onion', 'orange', 'paprika', 'pear', 'peas', 'pineapple', 'pomegranate', 'potato', 'raddish', 'soy beans',
46
- 'spinach', 'sweetcorn', 'sweetpotato', 'tomato', 'turnip', 'watermelon']
 
47
 
48
  # Infer uploaded image function
49
  def infer_image(pil_image):
50
  img = pil_image.resize((224, 224))
51
  IMG = np.expand_dims(np.array(img) / 255.0, axis=0)
52
-
53
- preds = MODEL.predict(IMG)[0]
54
 
 
55
  top_idxs = np.argsort(preds)[::-1][:3]
56
 
57
  # ingredient list
58
  ingredients = []
59
-
60
  for i in top_idxs:
61
- confidence = float(preds[i])
62
-
63
  ingredients.append({
64
  "name": CLASS_NAMES[i].capitalize(),
65
- "confidence": confidence
66
  })
67
 
68
  # Limit to top 5 ingredients
69
  if len(ingredients) >= 5:
70
  break
71
-
72
  # incase of no prediction.
73
  if not ingredients:
74
  return [{"name": "unknown", "confidence": 0.0}]
@@ -76,27 +146,6 @@ def infer_image(pil_image):
76
  return ingredients
77
 
78
 
79
- # Fallback Recipe Generator -> GPT-2
80
- def generate_recipe_local(ingredient_names):
81
- ingredients = ", ".join(ingredient_names)
82
- return f"""
83
- # Simple Local Fallback Recipe
84
-
85
- Since no API key was provided, here is a simple offline recipe.
86
-
87
- ## Ingredients
88
- - {ingredients}
89
-
90
- ## Steps
91
- 1. Combine {ingredients} in a bowl.
92
- 2. Add salt and seasoning as desired.
93
- 3. Cook for 10 minutes.
94
- 4. Serve warm.
95
-
96
- *(Generated locally without external AI models.)*
97
- """.strip()
98
-
99
-
100
  # initialize FastAPI app
101
  app = FastAPI(
102
  title="Fridge2Dish API",
@@ -118,6 +167,7 @@ app.add_middleware(
118
  )
119
 
120
 
 
121
  # ROUTES
122
 
123
  # Home Route
@@ -134,7 +184,6 @@ async def upload_image(
134
  ):
135
 
136
  try:
137
- # check image file
138
  if not file.filename.lower().endswith((".jpg", ".jpeg", ".png")):
139
  raise HTTPException(status_code=400, detail="Invalid image format.")
140
 
@@ -148,24 +197,17 @@ async def upload_image(
148
  ingredients = infer_image(pil_img)
149
  end_time = time.time()
150
 
151
- print(f"Ingredient detection took {end_time - start_time:.2f} seconds")
152
 
153
- print(f"Detected ingredients: {ingredients}")
154
 
155
- if not ingredients:
156
- return {"ingredients": [],
157
- "recipe": "No ingredients detected, Try to take a clearer picture."}
158
-
159
-
160
  ingredient_names = [item["name"] for item in ingredients]
161
-
162
-
163
- # Recipe generation using Gemini
164
- # Get api key from user input
165
  api_key = user_api_key.strip()
166
-
167
  if api_key:
168
-
169
  try:
170
  # Try Gemini first...
171
  genai.configure(api_key=api_key)
@@ -179,36 +221,25 @@ async def upload_image(
179
  - Ingredients list with quantities
180
  - 6-10 concise steps
181
  - Optional fun tips or variations
182
- Make it easy to follow and appetizing!
183
-
184
- Do not include any lines like "Sure! Here's a recipe...", "Here's a simple..." or similar.
185
  Return results in markdown format.
186
  """
187
-
188
- print("\n\nTrying Gemini...")
189
  response = model.generate_content(prompt)
190
  recipe_text = response.text.strip()
191
- # print(recipe_text)
192
-
193
 
194
  except Exception as e1:
195
- print("\nGemini failed. Falling to GPT-2:", e1)
196
  recipe_text = generate_recipe_local(ingredient_names)
197
-
198
  else:
199
- # Since no api_key -> fallback to GPT-2
200
- print("\n\nNo API key provided. Using GPT-2 to generate recipe")
201
  recipe_text = generate_recipe_local(ingredient_names)
202
-
203
- # results
204
- return {
205
- "ingredients": ingredients,
206
- "recipe": recipe_text,
207
- }
208
-
209
- except Exception as e2:
210
  traceback.print_exc()
211
- raise HTTPException(status_code=500, detail=f"Server Error: {str(e2)}")
212
 
213
 
214
  # Health check
@@ -219,4 +250,4 @@ def health():
219
 
220
  # Run app
221
  if __name__ == "__main__":
222
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
3
  # import libraries
4
  import os
5
  import io
 
 
6
  import time
7
+ import traceback
 
8
 
9
+ import uvicorn
10
+ import numpy as np
11
  from PIL import Image
12
  from fastapi import FastAPI, Form, UploadFile, File, Request, HTTPException
13
  from fastapi.responses import HTMLResponse
14
  from fastapi.staticfiles import StaticFiles
15
  from fastapi.templating import Jinja2Templates
16
  from fastapi.middleware.cors import CORSMiddleware
17
+
18
+ # import ML libraries
19
+ import tensorflow as tf
20
  import google.generativeai as genai
21
 
22
+ # Transformers libraries (for fallback)
23
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
24
+ import torch
25
+ import threading
26
+
27
 
28
# Persistent storage for the GPT-2 fallback model.
# On HF Spaces the /data folder survives restarts, so the model is
# downloaded once and reloaded from disk on every later boot.
LOCAL_GPT2_DIR = "/data/gpt2"  # HF Spaces persistent folder
REMOTE_GPT2_NAME = "gpt2-medium"

# Lazily-initialized text-generation pipeline, plus a lock so that
# concurrent requests cannot trigger duplicate downloads/loads.
_local_generator = None
_local_lock = threading.Lock()
34
 
 
 
35
 
36
def load_or_download_gpt2():
    """Return the cached text-generation pipeline for the GPT-2 fallback.

    Downloads GPT-2-medium into ``LOCAL_GPT2_DIR`` (/data/gpt2, the HF
    Spaces persistent folder) on the first call; on subsequent runs it
    loads the saved local copy. The pipeline is built exactly once and
    reused — double-checked locking keeps concurrent requests from
    loading the model twice.

    Returns:
        transformers.Pipeline: a ready "text-generation" pipeline.
    """
    global _local_generator

    # Fast path: already initialized, skip the lock entirely.
    if _local_generator is not None:
        return _local_generator

    with _local_lock:
        # Re-check under the lock: another thread may have won the race.
        if _local_generator is not None:
            return _local_generator

        os.makedirs(LOCAL_GPT2_DIR, exist_ok=True)

        # makedirs() above guarantees the directory exists, so the only
        # meaningful test is whether it is non-empty (cache hit) — the
        # original extra os.path.exists() check was always True here.
        if os.listdir(LOCAL_GPT2_DIR):
            # Case 1: local copy already saved — load from storage.
            print("\n🔵 Loading GPT-2 from /data/gpt2 (local cache)...")
            tokenizer = AutoTokenizer.from_pretrained(LOCAL_GPT2_DIR)
            model = AutoModelForCausalLM.from_pretrained(LOCAL_GPT2_DIR)
        else:
            # Case 2: first run — download, then persist for next boot.
            print("\n🟡 Downloading GPT-2-medium... (first run)")
            tokenizer = AutoTokenizer.from_pretrained(REMOTE_GPT2_NAME)
            model = AutoModelForCausalLM.from_pretrained(REMOTE_GPT2_NAME)

            print("\n🟢 Saving GPT-2-medium to /data/gpt2...")
            tokenizer.save_pretrained(LOCAL_GPT2_DIR)
            model.save_pretrained(LOCAL_GPT2_DIR)

        # device=0 selects the first GPU; -1 means CPU for transformers.
        device = 0 if torch.cuda.is_available() else -1
        _local_generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=device,
        )
        print("\n\n✅ GPT-2 loaded and ready.")
        return _local_generator
78
 
79
+
80
+
81
def generate_recipe_local(ingredient_names, max_new_tokens=300, temperature=0.9):
    """Generate a recipe offline with the local GPT-2-medium fallback.

    Args:
        ingredient_names: ingredient name strings to cook with.
        max_new_tokens: generation budget for the sampled continuation.
        temperature: sampling temperature (higher = more varied output).

    Returns:
        str: the generated recipe text with the echoed prompt removed.
    """
    generator = load_or_download_gpt2()

    joined = ", ".join(ingredient_names)
    prompt = (
        f"You are an AI chef. Create a short recipe using only these ingredients: "
        f"{joined}.\n\n"
        "- Start with a recipe name on its own line.\n"
        "- Then a one-sentence description.\n"
        "- Then a bullet list of ingredients with approximate quantities.\n"
        "- Then 6-10 concise numbered steps.\n"
        "- Optionally one quick tip.\n\nRecipe:\n"
    )

    result = generator(
        prompt,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
    )
    text = result[0]["generated_text"]

    # The pipeline echoes the prompt; keep only what follows the marker.
    marker = "Recipe:"
    if marker in text:
        text = text.split(marker, 1)[1].strip()

    return text.strip()
106
+
107
+
108
+
109
# Load the ingredient classifier once at startup (module-level global,
# shared by all requests).
MODEL_PATH = "models/ingredient_model.h5"
MODEL = tf.keras.models.load_model(MODEL_PATH)

# Class names: prefer the live training directory so labels always track
# the dataset, but fall back to the known label list when the dataset
# folder is not shipped with the deployment (otherwise the whole app
# crashed at import time with FileNotFoundError).
try:
    CLASS_NAMES = sorted(os.listdir("dataset/dataset_2/train"))
except OSError:
    CLASS_NAMES = [
        'apple', 'banana', 'beetroot', 'bell pepper', 'cabbage', 'capsicum', 'carrot', 'cauliflower', 'chilli pepper',
        'corn', 'cucumber', 'eggplant', 'garlic', 'ginger', 'grapes', 'jalepeno', 'kiwi', 'lemon', 'lettuce', 'mango',
        'onion', 'orange', 'paprika', 'pear', 'peas', 'pineapple', 'pomegranate', 'potato', 'raddish', 'soy beans',
        'spinach', 'sweetcorn', 'sweetpotato', 'tomato', 'turnip', 'watermelon']
print(CLASS_NAMES)
121
 
122
# Infer uploaded image function
def infer_image(pil_image):
    """Classify a PIL image and return the top ingredient predictions.

    Args:
        pil_image: the uploaded image as a PIL.Image.

    Returns:
        list[dict]: up to 3 ``{"name", "confidence"}`` entries sorted by
        descending confidence, or a single ``"unknown"`` entry when the
        model produced no predictions.
    """
    # Normalize to 3-channel RGB first: uploads may be RGBA (PNG) or
    # grayscale, which would otherwise break the expected
    # (1, 224, 224, 3) input tensor shape.
    img = pil_image.convert("RGB").resize((224, 224))
    batch = np.expand_dims(np.array(img) / 255.0, axis=0)

    preds = MODEL.predict(batch)[0]

    # Indices of the 3 highest-confidence classes, best first.
    # (The old "limit to top 5" break was dead code — only 3 candidates.)
    top_idxs = np.argsort(preds)[::-1][:3]

    # ingredient list
    ingredients = [
        {"name": CLASS_NAMES[i].capitalize(), "confidence": float(preds[i])}
        for i in top_idxs
    ]

    # incase of no prediction.
    if not ingredients:
        return [{"name": "unknown", "confidence": 0.0}]

    return ingredients
147
 
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  # initialize FastAPI app
150
  app = FastAPI(
151
  title="Fridge2Dish API",
 
167
  )
168
 
169
 
170
+
171
  # ROUTES
172
 
173
  # Home Route
 
184
  ):
185
 
186
  try:
 
187
  if not file.filename.lower().endswith((".jpg", ".jpeg", ".png")):
188
  raise HTTPException(status_code=400, detail="Invalid image format.")
189
 
 
197
  ingredients = infer_image(pil_img)
198
  end_time = time.time()
199
 
200
+ print(f"\nIngredient detection took {end_time - start_time:.2f} seconds")
201
 
202
+ print(f"\nDetected ingredients: {ingredients}")
203
 
 
 
 
 
 
204
  ingredient_names = [item["name"] for item in ingredients]
205
+
206
+
207
+ # Recipe generation using Gemini or GPT-2 fallback
 
208
  api_key = user_api_key.strip()
209
+
210
  if api_key:
 
211
  try:
212
  # Try Gemini first...
213
  genai.configure(api_key=api_key)
 
221
  - Ingredients list with quantities
222
  - 6-10 concise steps
223
  - Optional fun tips or variations
 
 
 
224
  Return results in markdown format.
225
  """
226
+
 
227
  response = model.generate_content(prompt)
228
  recipe_text = response.text.strip()
 
 
229
 
230
  except Exception as e1:
231
+ print("\nGemini failed switching to GPT-2:", e1)
232
  recipe_text = generate_recipe_local(ingredient_names)
233
+
234
  else:
235
+ print("\nNo API key provided using GPT-2 fallback.")
 
236
  recipe_text = generate_recipe_local(ingredient_names)
237
+
238
+ return {"ingredients": ingredients, "recipe": recipe_text}
239
+
240
+ except Exception as e:
 
 
 
 
241
  traceback.print_exc()
242
+ raise HTTPException(status_code=500, detail=f"Server Error: {str(e)}")
243
 
244
 
245
  # Health check
 
250
 
251
# Run app
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the conventional HF Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)