Wills17 commited on
Commit
ec179db
·
verified ·
1 Parent(s): 043e2e5

Update FastAPI_app.py

Browse files
Files changed (1) hide show
  1. FastAPI_app.py +168 -143
FastAPI_app.py CHANGED
@@ -1,9 +1,11 @@
1
  # FastAPI application for Fridge2Dish
2
 
 
3
  import os
4
  import io
5
  import time
6
  import traceback
 
7
 
8
  import uvicorn
9
  import numpy as np
@@ -18,186 +20,195 @@ from fastapi.middleware.cors import CORSMiddleware
18
  import tensorflow as tf
19
  import google.generativeai as genai
20
 
21
- # Transformers libraries (for fallback)
22
  from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
23
  import torch
24
- import threading
25
 
26
 
27
- # create presistent storage for GPT-2
28
- LOCAL_GPT2_DIR = "/data/gpt2" # HF Spaces persistent folder
29
- REMOTE_GPT2_NAME = "gpt2-medium"
 
 
 
30
 
31
- _local_generator = None
32
- _local_lock = threading.Lock()
 
33
 
34
 
35
- def load_or_download_gpt2():
 
36
  """
37
- This function downloads GPT-2-medium into `/data/gpt2` on first run.
38
- And on subsequent runs, it loads the saved local version.
 
39
  """
 
 
 
40
 
41
- global _local_generator
42
- if _local_generator is not None:
43
- return _local_generator
44
 
45
- with _local_lock:
46
- if _local_generator is not None:
47
- return _local_generator
48
-
49
- os.makedirs(LOCAL_GPT2_DIR, exist_ok=True)
50
-
51
- # Load from cache
52
- if os.listdir(LOCAL_GPT2_DIR):
53
- print("\n🔵 Loading GPT-2 from local cache...")
54
- tokenizer = AutoTokenizer.from_pretrained(LOCAL_GPT2_DIR)
55
- model = AutoModelForCausalLM.from_pretrained(LOCAL_GPT2_DIR)
56
 
 
 
 
 
 
57
  else:
58
- # First-time download
59
- print("\n🟡 Downloading GPT-2-medium...")
60
- tokenizer = AutoTokenizer.from_pretrained(REMOTE_GPT2_NAME)
61
- model = AutoModelForCausalLM.from_pretrained(REMOTE_GPT2_NAME)
62
-
63
- print("\n🟢 Saving GPT-2-medium to persistent storage...")
64
- tokenizer.save_pretrained(LOCAL_GPT2_DIR)
65
- model.save_pretrained(LOCAL_GPT2_DIR)
66
-
67
  device = 0 if torch.cuda.is_available() else -1
68
- _local_generator = pipeline(
 
69
  "text-generation",
70
  model=model,
71
  tokenizer=tokenizer,
72
  device=device,
 
 
 
 
 
73
  )
 
 
74
 
75
- print("\n\n✅ GPT-2 ready for generation.")
76
- return _local_generator
77
 
78
 
79
- # improve GPT-2 recipe generation
80
- def clean_output(text: str) -> str:
81
  """
82
- Remove garbage, repeated sentences, disclaimers, and anything before the recipe.
 
 
 
83
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- # Remove leading garbage before a recognizable title
86
- for key in ["Ingredients", "Recipe", "###", "Steps"]:
87
- if key in text:
88
- text = text.split(key, 1)[1]
89
- text = key + text
90
- break
91
-
92
- # Remove repeated lines
93
- cleaned = []
94
- seen = set()
95
-
96
- for line in text.split("\n"):
97
- l = line.strip()
98
- if l not in seen:
99
- seen.add(l)
100
- cleaned.append(line)
101
-
102
- return "\n".join(cleaned).strip()
103
-
104
 
105
 
106
- def generate_recipe_local(ingredient_names):
 
107
  """
108
- Improved GPT-2 recipe generation with strict formatting.
109
  """
110
- generator = load_or_download_gpt2()
111
-
112
- prompt = f"""
113
- You are an AI chef. Create a clean, structured recipe using ONLY these ingredients:
114
- {', '.join(ingredient_names)}.
115
-
116
- STRICT RULES:
117
- 1. Start with a short recipe title (one line).
118
- 2. Then one-sentence description.
119
- 3. Then a section titled "### Ingredients" with bullet points.
120
- 4. Use quantities (approximate is okay).
121
- 5. Then "### Steps" with 6–10 numbered steps.
122
- 6. Keep it short, clear, and well formatted.
123
- 7. No rambling. No repeating. No intros. No disclaimers.
124
- 8. End after the steps.
125
-
126
- FORMAT EXAMPLE:
127
-
128
- Title
129
- Short description.
130
-
131
- ### Ingredients
132
- - item
133
- - item
134
-
135
- ### Steps
136
- 1. step
137
- 2. step
138
- 3. step
139
-
140
- Generate the recipe:
141
- """
142
-
143
- output = generator(
144
- prompt,
145
- max_new_tokens=180,
146
- temperature=0.7,
147
- do_sample=True,
148
- top_p=0.95,
149
- num_return_sequences=1
150
- )[0]["generated_text"]
151
-
152
- cleaned = clean_output(output)
153
-
154
  return cleaned
155
 
156
 
157
 
158
- # Load ingredients model once startup.
159
- MODEL_PATH = "models/ingredient_model.h5"
160
  MODEL = tf.keras.models.load_model(MODEL_PATH)
161
 
162
- CLASS_NAMES = [
163
- 'apple', 'banana', 'beetroot', 'bell pepper', 'cabbage', 'capsicum', 'carrot', 'cauliflower', 'chilli pepper',
164
- 'corn', 'cucumber', 'eggplant', 'garlic', 'ginger', 'grapes', 'jalepeno', 'kiwi', 'lemon', 'lettuce', 'mango',
165
- 'onion', 'orange', 'paprika', 'pear', 'peas', 'pineapple', 'pomegranate', 'potato', 'raddish', 'soy beans',
166
- 'spinach', 'sweetcorn', 'sweetpotato', 'tomato', 'turnip', 'watermelon']
 
 
 
 
 
 
 
167
 
168
  # Infer uploaded image function
169
  def infer_image(pil_image):
 
 
 
170
  img = pil_image.resize((224, 224))
171
  arr = np.expand_dims(np.array(img) / 255.0, axis=0)
172
-
173
  preds = MODEL.predict(arr)[0]
 
174
  top_idxs = np.argsort(preds)[::-1][:3]
175
-
176
- ingredients = [
177
- {"name": CLASS_NAMES[i].capitalize(), "confidence": float(preds[i])}
178
- for i in top_idxs
179
- ]
180
-
181
- return ingredients or [{"name": "Unknown", "confidence": 0.0}]
182
-
183
 
184
 
185
  # initialize FastAPI app
186
  app = FastAPI(
187
  title="Fridge2Dish",
188
- description="Upload image → Detect ingredients → Generate recipes",
189
- version="2.0.0"
190
  )
191
 
192
- # Serve static files
193
  app.mount("/static", StaticFiles(directory="static"), name="static")
194
  templates = Jinja2Templates(directory="templates")
195
 
196
  # CORS
197
  app.add_middleware(
198
  CORSMiddleware,
199
- allow_origins=["*"], allow_credentials=True,
200
- allow_methods=["*"], allow_headers=["*"]
 
 
201
  )
202
 
203
 
@@ -215,24 +226,32 @@ def home(request: Request):
215
  async def upload_image(
216
  file: UploadFile = File(...),
217
  user_api_key: str = Form(alias="api_key", default="")
218
- ):
219
-
220
  try:
221
  if not file.filename.lower().endswith((".jpg", ".jpeg", ".png")):
222
  raise HTTPException(status_code=400, detail="Invalid image format.")
223
-
224
-
225
- image_bytes = await file.read()
226
- pil_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
227
-
228
- # Load image
 
229
  ingredients = infer_image(pil_img)
230
- ingredient_names = [i["name"] for i in ingredients]
 
231
 
232
- # Try Gemini if user provided a key
233
- if user_api_key.strip():
 
 
 
 
 
234
  try:
235
- genai.configure(api_key=user_api_key.strip())
 
236
  model = genai.GenerativeModel("gemini-2.5-flash")
237
 
238
  prompt = f"""
@@ -248,17 +267,23 @@ async def upload_image(
248
 
249
  response = model.generate_content(prompt)
250
  recipe_text = response.text.strip()
251
-
252
- except Exception as e1:
253
- print("\n⚠ Gemini failed. Switching to GPT-2 fallback.", e1)
254
- recipe_text = generate_recipe_local(ingredient_names)
 
 
255
 
256
  else:
257
- print("\nNo API key → Using GPT-2 fallback")
258
- recipe_text = generate_recipe_local(ingredient_names)
 
259
 
 
260
  return {"ingredients": ingredients, "recipe": recipe_text}
261
 
 
 
262
  except Exception as e:
263
  traceback.print_exc()
264
  raise HTTPException(status_code=500, detail=f"Server Error: {str(e)}")
 
1
  # FastAPI application for Fridge2Dish
2
 
3
+ # import libraries
4
  import os
5
  import io
6
  import time
7
  import traceback
8
+ import threading
9
 
10
  import uvicorn
11
  import numpy as np
 
20
  import tensorflow as tf
21
  import google.generativeai as genai
22
 
23
+ # Transformers libraries (Gemma local fallback)
24
  from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
25
  import torch
 
26
 
27
 
28
+ # create presistent storage for Gemma-2-2b-it model
29
+ LOCAL_GEMMA_DIR = "/data/gemma-2-2b-it"
30
+ REMOTE_GEMMA_NAME = "google/gemma-2-2b-it"
31
+
32
+ # Load ingredients model
33
+ MODEL_PATH = "models/ingredient_model.h5"
34
 
35
+ # Protect loading the large local Gemma model by locking.
36
+ _gemma_lock = threading.Lock()
37
+ _gemma_pipeline = None
38
 
39
 
40
+ # load or download (as applicable) the Gemma model
41
+ def load_or_download_gemma():
42
  """
43
+ Loads a local Gemma-2-2b-it pipeline from LOCAL_GEMMA_DIR if present,
44
+ otherwise downloads from Hugging Face and saves into LOCAL_GEMMA_DIR.
45
+ Returns a transformers text-generation pipeline.
46
  """
47
+ global _gemma_pipeline
48
+ if _gemma_pipeline is not None:
49
+ return _gemma_pipeline
50
 
51
+ with _gemma_lock:
52
+ if _gemma_pipeline is not None:
53
+ return _gemma_pipeline
54
 
55
+ os.makedirs(LOCAL_GEMMA_DIR, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
56
 
57
+ # If local folder already populated, load from there
58
+ if os.listdir(LOCAL_GEMMA_DIR):
59
+ print("\n🔵 Loading Gemma-2-2b-it from local cache:", LOCAL_GEMMA_DIR)
60
+ tokenizer = AutoTokenizer.from_pretrained(LOCAL_GEMMA_DIR, trust_remote_code=True)
61
+ model = AutoModelForCausalLM.from_pretrained(LOCAL_GEMMA_DIR, trust_remote_code=True)
62
  else:
63
+ # Download and save locally
64
+ print("\n🟡 Downloading Gemma-2-2b-it from Hugging Face (first run)...")
65
+ tokenizer = AutoTokenizer.from_pretrained(REMOTE_GEMMA_NAME, trust_remote_code=True)
66
+ model = AutoModelForCausalLM.from_pretrained(REMOTE_GEMMA_NAME, trust_remote_code=True)
67
+ print("\n🟢 Saving Gemma to local persistent directory:", LOCAL_GEMMA_DIR)
68
+ tokenizer.save_pretrained(LOCAL_GEMMA_DIR)
69
+ model.save_pretrained(LOCAL_GEMMA_DIR)
70
+
71
+ # Choose device: GPU if available, otherwise CPU
72
  device = 0 if torch.cuda.is_available() else -1
73
+ print(f"\n[Gemma] creating pipeline (device={device}) -- this may take a moment")
74
+ _gemma_pipeline = pipeline(
75
  "text-generation",
76
  model=model,
77
  tokenizer=tokenizer,
78
  device=device,
79
+ # reduce returned tokens to keep small responses
80
+ max_new_tokens=300,
81
+ do_sample=True,
82
+ top_p=0.95,
83
+ temperature=0.7
84
  )
85
+ print("[Gemma] loaded and ready")
86
+ return _gemma_pipeline
87
 
 
 
88
 
89
 
90
+ # improve LM output by cleaning
91
+ def _clean_generated_text(text: str) -> str:
92
  """
93
+ Basic cleaning of the LM output:
94
+ - remove obvious leading garbage,
95
+ - remove repeated lines,
96
+ - trim long tails after a natural stopping point.
97
  """
98
+ if not text:
99
+ return ""
100
+
101
+ # If model echoes prompt, try to cut at 'Recipe' or '### Ingredients' or similar markers
102
+ markers = ["### Ingredients", "### Steps", "Ingredients:", "Steps:", "Recipe"]
103
+ for m in markers:
104
+ if m in text:
105
+ # keep starting at the marker if there is garbage before
106
+ try:
107
+ idx = text.index(m)
108
+ text = text[idx:]
109
+ break
110
+ except ValueError:
111
+ pass
112
+
113
+ # Deduplicate repeated consecutive lines
114
+ out_lines = []
115
+ prev = None
116
+ for line in text.splitlines():
117
+ s = line.rstrip()
118
+ if s and s == prev:
119
+ continue
120
+ out_lines.append(line)
121
+ prev = s
122
+
123
+ cleaned = "\n".join(out_lines).strip()
124
+ # Trim at a long trailing repeated token if present
125
+ if len(cleaned) > 2000:
126
+ cleaned = cleaned[:2000].rsplit("\n", 1)[0]
127
 
128
+ return cleaned
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
 
131
+ # generate recipe using local Gemma
132
+ def generate_recipe_local_gemma(ingredient_names):
133
  """
134
+ Use local Gemma pipeline to generate a well-formatted recipe in markdown.
135
  """
136
+ gen = load_or_download_gemma()
137
+
138
+ prompt = (
139
+ "You are a professional chef and recipe writer. Create a concise, well-formatted recipe in Markdown "
140
+ f"using ONLY the following ingredients: {', '.join(ingredient_names)}.\n\n"
141
+ "Requirements:\n"
142
+ "- Start with the recipe title on one line.\n"
143
+ "- One-sentence description.\n"
144
+ "- Then a '### Ingredients' section with bullet points and approximate quantities.\n"
145
+ "- Then a '### Steps' section with 6-8 numbered steps.\n"
146
+ "- Keep it concise, no filler, no disclaimers, and end after the steps.\n\n"
147
+ "Output only the recipe in Markdown.\n\nRecipe:\n"
148
+ )
149
+
150
+ out = gen(prompt, do_sample=True, temperature=0.7, top_p=0.95, max_new_tokens=300, num_return_sequences=1)
151
+ generated = out[0].get("generated_text", "")
152
+ # If the model reprints the prompt, remove the leading prompt part:
153
+ if "Recipe:" in generated:
154
+ generated = generated.split("Recipe:", 1)[1].strip()
155
+ cleaned = _clean_generated_text(generated)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  return cleaned
157
 
158
 
159
 
160
+ # Ingredient detection model loading
 
161
  MODEL = tf.keras.models.load_model(MODEL_PATH)
162
 
163
+
164
+ # Class names from folder or manual.
165
+ if os.path.isdir("dataset/dataset_2/train"):
166
+ CLASS_NAMES = sorted(os.listdir("dataset/dataset_2/train"))
167
+
168
+ else:
169
+ CLASS_NAMES = [
170
+ 'apple', 'banana', 'beetroot', 'bell pepper', 'cabbage', 'capsicum', 'carrot', 'cauliflower', 'chilli pepper',
171
+ 'corn', 'cucumber', 'eggplant', 'garlic', 'ginger', 'grapes', 'jalepeno', 'kiwi', 'lemon', 'lettuce', 'mango',
172
+ 'onion', 'orange', 'paprika', 'pear', 'peas', 'pineapple', 'pomegranate', 'potato', 'raddish', 'soy beans',
173
+ 'spinach', 'sweetcorn', 'sweetpotato', 'tomato', 'turnip', 'watermelon']
174
+
175
 
176
  # Infer uploaded image function
177
  def infer_image(pil_image):
178
+ """
179
+ Returns a list of dicts: [{ "name": CapitalizedName, "confidence": 0.xx }, ...]
180
+ """
181
  img = pil_image.resize((224, 224))
182
  arr = np.expand_dims(np.array(img) / 255.0, axis=0)
 
183
  preds = MODEL.predict(arr)[0]
184
+ # Top 3 predictions
185
  top_idxs = np.argsort(preds)[::-1][:3]
186
+ ingredients = []
187
+ for i in top_idxs:
188
+ ingredients.append({"name": CLASS_NAMES[i].capitalize(), "confidence": float(preds[i])})
189
+ if not ingredients:
190
+ return [{"name": "Unknown", "confidence": 0.0}]
191
+ return ingredients
 
 
192
 
193
 
194
  # initialize FastAPI app
195
  app = FastAPI(
196
  title="Fridge2Dish",
197
+ description="Upload an image → Detect ingredients → Generate recipes",
198
+ version="3.0.0"
199
  )
200
 
201
+ # static/templates
202
  app.mount("/static", StaticFiles(directory="static"), name="static")
203
  templates = Jinja2Templates(directory="templates")
204
 
205
  # CORS
206
  app.add_middleware(
207
  CORSMiddleware,
208
+ allow_origins=["*"],
209
+ allow_credentials=True,
210
+ allow_methods=["*"],
211
+ allow_headers=["*"],
212
  )
213
 
214
 
 
226
  async def upload_image(
227
  file: UploadFile = File(...),
228
  user_api_key: str = Form(alias="api_key", default="")
229
+ ):
230
+
231
  try:
232
  if not file.filename.lower().endswith((".jpg", ".jpeg", ".png")):
233
  raise HTTPException(status_code=400, detail="Invalid image format.")
234
+
235
+ # read image
236
+ img_bytes = await file.read()
237
+ pil_img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
238
+
239
+ # detect ingredients
240
+ start = time.time()
241
  ingredients = infer_image(pil_img)
242
+ dur = time.time() - start
243
+ print(f"Detected ingredients: {ingredients} (took {dur:.2f}s)")
244
 
245
+ ingredient_names = [it["name"] for it in ingredients]
246
+
247
+ recipe_text = None
248
+ api_key = user_api_key.strip()
249
+
250
+ # Try server Gemini if api_key provided
251
+ if api_key:
252
  try:
253
+ # Try Gemini first...
254
+ genai.configure(api_key=api_key)
255
  model = genai.GenerativeModel("gemini-2.5-flash")
256
 
257
  prompt = f"""
 
267
 
268
  response = model.generate_content(prompt)
269
  recipe_text = response.text.strip()
270
+ print("\nGemini succeeded.")
271
+
272
+ except Exception as e_gem:
273
+ # Log and fallback to local Gemma
274
+ print("Gemini failed or threw exception; falling back to local Gemma:", e_gem)
275
+ recipe_text = generate_recipe_local_gemma(ingredient_names)
276
 
277
  else:
278
+ # No API key -> local Gemma
279
+ print("\nNo API key provided -> Using local Gemma fallback.")
280
+ recipe_text = generate_recipe_local_gemma(ingredient_names)
281
 
282
+ # Return structured response (ingredients keep confidence)
283
  return {"ingredients": ingredients, "recipe": recipe_text}
284
 
285
+ except HTTPException:
286
+ raise
287
  except Exception as e:
288
  traceback.print_exc()
289
  raise HTTPException(status_code=500, detail=f"Server Error: {str(e)}")