Wills17 committed
Commit d148f3b · verified · 1 Parent(s): 896c9d4

Update FastAPI_app.py

Files changed (1)
  1. FastAPI_app.py +144 -187
FastAPI_app.py CHANGED
@@ -1,4 +1,5 @@
 # FastAPI application for Fridge2Dish
+# Fallback: OpenChef-3B-v2 (GGUF) via llama-cpp-python
 
 # import libraries
 import os
@@ -20,166 +21,32 @@ from fastapi.middleware.cors import CORSMiddleware
 import tensorflow as tf
 import google.generativeai as genai
 
-# Transformers libraries (Gemma local fallback)
-from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
-import torch
-
-
-# Gemma model download status
-GEMMA_STATUS = {
-    "downloading": False,
-    "completed": False,
-    "error": None
-}
-
-
-# create persistent storage for the Gemma-2b-it model
-LOCAL_GEMMA_DIR = "/data/gemma-2b-it"
-GEMMA_MODEL_NAME = "google/gemma-2b-it"
-
-# Load ingredients model
+# llama-cpp-python for GGUF fallback
+try:
+    from llama_cpp import Llama
+except Exception as e:
+    Llama = None
+    print("Warning: llama_cpp not available. Install llama-cpp-python to use the local OpenChef fallback.", e)
+
+
+# -----------------------------
+# CONFIG - adjust this path
+# -----------------------------
+# Set LOCAL_GGUF_PATH to the path of the OpenChef-3B-v2 GGUF file that you've
+# uploaded into the repo/persistent storage. Example:
+# LOCAL_GGUF_PATH = "/data/OpenChef-3B-v2.Q4_K_M.gguf"
+#
+# Developer note: replace the value below with the actual uploaded file path.
+LOCAL_GGUF_PATH = "models/OpenChef-3B-v2.Q4_K_M.gguf"
+# -----------------------------
+
+
+# Ingredient model (load once)
 MODEL_PATH = "models/ingredient_model.h5"
+if not os.path.exists(MODEL_PATH):
+    raise FileNotFoundError(f"Ingredient model not found at {MODEL_PATH}")
 
-# Protect loading the large local Gemma model by locking.
-_local_lock = threading.Lock()
-_local_generator = None
-
-
-# load or download (as applicable) the Gemma model
-def load_or_download_gemma():
-    global _local_generator, GEMMA_STATUS
-    if _local_generator is not None:
-        return _local_generator
-
-    with _local_lock:
-        if _local_generator is not None:
-            return _local_generator
-
-        os.makedirs(LOCAL_GEMMA_DIR, exist_ok=True)
-
-        try:
-            # Mark download start
-            if not os.listdir(LOCAL_GEMMA_DIR):
-                GEMMA_STATUS["downloading"] = True
-                GEMMA_STATUS["completed"] = False
-                GEMMA_STATUS["error"] = None
-                print("\n🟡 Downloading Gemma-2b-it from Hugging Face (first run)...")
-
-                tokenizer = AutoTokenizer.from_pretrained(GEMMA_MODEL_NAME)
-                model = AutoModelForCausalLM.from_pretrained(GEMMA_MODEL_NAME)
-
-                print("\n🟢 Saving Gemma model to persistent storage…")
-                tokenizer.save_pretrained(LOCAL_GEMMA_DIR)
-                model.save_pretrained(LOCAL_GEMMA_DIR)
-
-            else:
-                print("\n🔵 Loading Gemma from local cache…")
-                tokenizer = AutoTokenizer.from_pretrained(LOCAL_GEMMA_DIR)
-                model = AutoModelForCausalLM.from_pretrained(LOCAL_GEMMA_DIR)
-
-            GEMMA_STATUS["downloading"] = False
-            GEMMA_STATUS["completed"] = True
-
-        except Exception as e:
-            GEMMA_STATUS["downloading"] = False
-            GEMMA_STATUS["completed"] = False
-            GEMMA_STATUS["error"] = str(e)
-            raise e
-
-        # Choose device: GPU if available, otherwise CPU
-        device = 0 if torch.cuda.is_available() else -1
-        print(f"\n[Gemma] creating pipeline (device={device}) -- this may take a moment")
-
-        _local_generator = pipeline(
-            "text-generation",
-            model=model,
-            tokenizer=tokenizer,
-            device=device,
-            # cap returned tokens to keep responses small
-            max_new_tokens=300,
-            do_sample=True,
-            top_p=0.95,
-            temperature=0.7
-        )
-
-        print("\n\n✅ Gemma ready for generation.")
-        return _local_generator
-
-
-# improve LM output by cleaning
-def _clean_generated_text(text: str) -> str:
-    """
-    Basic cleaning of the LM output:
-    - remove obvious leading garbage,
-    - remove repeated lines,
-    - trim long tails after a natural stopping point.
-    """
-    if not text:
-        return ""
-
-    # If the model echoes the prompt, cut at 'Recipe', '### Ingredients', or similar markers
-    markers = ["### Ingredients", "### Steps", "Ingredients:", "Steps:", "Recipe"]
-    for m in markers:
-        if m in text:
-            # keep starting at the marker if there is garbage before it
-            try:
-                idx = text.index(m)
-                text = text[idx:]
-                break
-            except ValueError:
-                pass
-
-    # Deduplicate repeated consecutive lines
-    out_lines = []
-    prev = None
-    for line in text.splitlines():
-        s = line.rstrip()
-        if s and s == prev:
-            continue
-        out_lines.append(line)
-        prev = s
-
-    cleaned = "\n".join(out_lines).strip()
-    # Trim at a long trailing repeated token if present
-    if len(cleaned) > 2000:
-        cleaned = cleaned[:2000].rsplit("\n", 1)[0]
-
-    return cleaned
-
-
-# generate recipe using local Gemma
-def generate_recipe_local_gemma(ingredient_names):
-    """
-    Use the local Gemma pipeline to generate a well-formatted recipe in Markdown.
-    """
-    gen = load_or_download_gemma()
-
-    prompt = (
-        "You are a professional chef and recipe writer. Create a concise, well-formatted recipe in Markdown "
-        f"using ONLY the following ingredients: {', '.join(ingredient_names)}.\n\n"
-        "Requirements:\n"
-        "- Start with the recipe title on one line.\n"
-        "- One-sentence description.\n"
-        "- Then a '### Ingredients' section with bullet points and approximate quantities.\n"
-        "- Then a '### Steps' section with 6-8 numbered steps.\n"
-        "- Keep it concise, no filler, no disclaimers, and end after the steps.\n\n"
-        "Output only the recipe in Markdown.\n\nRecipe:\n"
-    )
-
-    out = gen(prompt, do_sample=True, temperature=0.7, top_p=0.95, max_new_tokens=300, num_return_sequences=1)
-    generated = out[0].get("generated_text", "")
-
-    # If the model reprints the prompt, drop the leading prompt part:
-    if "Recipe:" in generated:
-        generated = generated.split("Recipe:", 1)[1].strip()
-    cleaned = _clean_generated_text(generated)
-    return cleaned
-
-
-# Ingredient detection model loading
 MODEL = tf.keras.models.load_model(MODEL_PATH)
 
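The config block above assumes the OpenChef-3B-v2 GGUF file has already been uploaded into the repo or persistent storage. If the file instead lives in a Hugging Face model repo, a one-time fetch at startup could resolve LOCAL_GGUF_PATH automatically. A minimal sketch, assuming huggingface_hub is installed; the repo_id and filename below are placeholders, not a confirmed repo:

    import os
    from huggingface_hub import hf_hub_download

    LOCAL_GGUF_PATH = "models/OpenChef-3B-v2.Q4_K_M.gguf"

    if not os.path.exists(LOCAL_GGUF_PATH):
        # One-time download into the directory LOCAL_GGUF_PATH points at.
        hf_hub_download(
            repo_id="your-org/OpenChef-3B-v2-GGUF",    # placeholder repo id
            filename="OpenChef-3B-v2.Q4_K_M.gguf",
            local_dir=os.path.dirname(LOCAL_GGUF_PATH),
        )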
@@ -203,16 +70,111 @@ def infer_image(pil_image):
     img = pil_image.resize((224, 224))
     arr = np.expand_dims(np.array(img) / 255.0, axis=0)
     preds = MODEL.predict(arr)[0]
-    # Top 3 predictions
-    top_idxs = np.argsort(preds)[::-1][:3]
+    top_idxs = np.argsort(preds)[::-1][:5]
     ingredients = []
     for i in top_idxs:
-        ingredients.append({"name": CLASS_NAMES[i].capitalize(), "confidence": float(preds[i])})
+        ingredients.append({
+            "name": CLASS_NAMES[i].capitalize(),
+            "confidence": float(preds[i])
+        })
     if not ingredients:
         return [{"name": "Unknown", "confidence": 0.0}]
     return ingredients
 
 
+# Protect loading by locking.
+_llama_lock = threading.Lock()
+_llama_model = None
+
+
+def load_local_openchef():
+    """Load the OpenChef GGUF via llama-cpp-python. Thread-safe and cached."""
+    global _llama_model
+    if _llama_model is not None:
+        return _llama_model
+
+    if Llama is None:
+        raise RuntimeError("llama_cpp is not installed. Install 'llama-cpp-python' to use the local OpenChef fallback.")
+
+    with _llama_lock:
+        if _llama_model is not None:
+            return _llama_model
+
+        if not os.path.exists(LOCAL_GGUF_PATH):
+            # be explicit about the missing model
+            raise FileNotFoundError(
+                f"Local OpenChef GGUF not found at {LOCAL_GGUF_PATH}. "
+                "Place the .gguf file there or update LOCAL_GGUF_PATH."
+            )
+
+        # instantiate; adjust n_ctx if needed
+        print(f"[openchef] Loading GGUF model from {LOCAL_GGUF_PATH} ...")
+        _llama_model = Llama(model_path=LOCAL_GGUF_PATH, n_ctx=2048)
+        print("[openchef] Loaded.")
+        return _llama_model
+
+
+def generate_recipe_local_openchef(ingredient_names: list, max_tokens: int = 512, temperature: float = 0.7):
+    """
+    Generate a markdown recipe using the local OpenChef (GGUF).
+    Returns plain text (markdown).
+    """
+    llama = load_local_openchef()
+
+    # clean ingredient list string
+    ing_str = ", ".join(ingredient_names)
+
+    prompt = f"""You are a concise AI chef. Use ONLY these ingredients: {ing_str}
+
+Rules:
+- Title on one line.
+- One-sentence description.
+- "### Ingredients" followed by a bullet list with approximate quantities.
+- "### Steps" followed by 6-8 numbered concise steps.
+- Optionally a "Tip:" line at the end.
+- No extra commentary, no apologies. Return only the recipe in markdown.
+
+Recipe:
+"""
+
+    # llama-cpp-python returns a dict with 'choices' or raw text depending on version.
+    # Use a completion with stop tokens to keep the output concise.
+    try:
+        resp = llama.create_completion(
+            prompt=prompt,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            top_p=0.95,
+            stop=["\n\n\n"]
+        )
+    except TypeError:
+        # older/newer llama-cpp-python API differences
+        resp = llama(prompt, max_tokens=max_tokens, temperature=temperature)
+
+    # extract text
+    # resp may be dict-like: {'choices': [{'text': '...'}], ...}
+    text = ""
+    try:
+        if isinstance(resp, dict) and "choices" in resp:
+            # new style
+            text = resp["choices"][0].get("text", "").strip()
+        elif hasattr(resp, "choices"):
+            text = resp.choices[0].text.strip()
+        elif isinstance(resp, str):
+            text = resp.strip()
+        else:
+            # fallback: str conversion
+            text = str(resp).strip()
+    except Exception:
+        text = str(resp).strip()
+
+    # sanity clean: if the model repeated the prompt, strip it
+    if text.startswith("Recipe:"):
+        text = text.split("Recipe:", 1)[1].strip()
+
+    return text
+
+
 # initialize FastAPI app
 app = FastAPI(
     title="Fridge2Dish",
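The fallback path added in the hunk above can be exercised without starting the server. A minimal smoke test, assuming the GGUF file is in place and this file imports as the module FastAPI_app (note the import also loads the ingredient .h5 model and builds the app, so those files must exist too):

    # Standalone check of the local OpenChef fallback.
    from FastAPI_app import generate_recipe_local_openchef

    recipe = generate_recipe_local_openchef(["Tomato", "Egg", "Onion"])
    print(recipe)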
 
@@ -237,19 +199,6 @@ app.add_middleware(
 
 # ROUTES
 
-# Gemma model download status tracking
-@app.get("/model-status")
-def model_status():
-    """
-    Reports whether the Gemma fallback model is downloaded, downloading, or errored.
-    """
-    return {
-        "downloading": GEMMA_STATUS["downloading"],
-        "completed": GEMMA_STATUS["completed"],
-        "error": GEMMA_STATUS["error"]
-    }
-
-
 # Home Route
 @app.get("/", response_class=HTMLResponse)
 def home(request: Request):
@@ -267,28 +216,27 @@ async def upload_image(
         if not file.filename.lower().endswith((".jpg", ".jpeg", ".png")):
             raise HTTPException(status_code=400, detail="Invalid image format.")
 
-        # read image
+        # load image
         img_bytes = await file.read()
         pil_img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
 
         # detect ingredients
         start = time.time()
         ingredients = infer_image(pil_img)
-        dur = time.time() - start
-        print(f"Detected ingredients: {ingredients} (took {dur:.2f}s)")
+        end = time.time()
+        print(f"Detected ingredients: {ingredients} (took {end-start:.2f}s)")
 
-        ingredient_names = [it["name"] for it in ingredients]
+        ingredient_names = [i["name"] for i in ingredients]
 
         recipe_text = None
-        api_key = user_api_key.strip()
+        api_key = (user_api_key or "").strip()
 
-        # Try server Gemini if api_key provided
         if api_key:
+            # try Gemini first
             try:
-                # Try Gemini first...
                 genai.configure(api_key=api_key)
                 model = genai.GenerativeModel("gemini-2.5-flash")
-
                 prompt = f"""
                 You are an AI chef. Create a short recipe using only: {', '.join(ingredient_names)}.
                 Include:
@@ -299,30 +247,39 @@
                 - Optional fun tips or variations
                 Return results in markdown format.
                 """
-
+
+                print("Trying Gemini...")
                 response = model.generate_content(prompt)
                 recipe_text = response.text.strip()
                 print("\nGemini succeeded.")
 
-            except Exception as e_gem:
-                # Log and fall back to local Gemma
-                print("Gemini failed or threw an exception; falling back to local Gemma:", e_gem)
-                recipe_text = generate_recipe_local_gemma(ingredient_names)
+            except Exception as e_gemini:
+                print("\nGemini failed:", e_gemini)
+                # fall back to local OpenChef
+                try:
+                    recipe_text = generate_recipe_local_openchef(ingredient_names)
+                except Exception as e_local:
+                    print("\nLocal OpenChef failed:", e_local)
+                    raise e_local
 
         else:
-            # No API key -> local Gemma
-            print("\nNo API key provided -> Using local Gemma fallback.")
-            recipe_text = generate_recipe_local_gemma(ingredient_names)
+            # no API key: use the local OpenChef fallback
+            try:
+                print("\nNo API key provided -> Using local OpenChef fallback.")
+                recipe_text = generate_recipe_local_openchef(ingredient_names)
+            except Exception as e_local:
+                print("Local OpenChef failed:", e_local)
+                raise e_local
 
-        # Return structured response (ingredients keep confidence)
         return {"ingredients": ingredients, "recipe": recipe_text}
 
     except HTTPException:
+        # re-raise known HTTP errors
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Server Error: {str(e)}")
 
 
 # Health check
 @app.get("/health")
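A quick end-to-end check of the updated endpoints, as a sketch. The upload route's decorator is outside this diff, so the "/upload" path and the user_api_key form field below are assumptions; match them to the actual @app.post(...) route in the file:

    import requests

    base = "http://localhost:8000"
    print(requests.get(f"{base}/health").json())

    with open("fridge.jpg", "rb") as f:
        resp = requests.post(
            f"{base}/upload",                              # assumed path
            files={"file": ("fridge.jpg", f, "image/jpeg")},
            data={"user_api_key": ""},                     # empty -> local OpenChef fallback
        )
    print(resp.json()["recipe"])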