TimInf committed on
Commit
d71c8ab
·
verified ·
1 Parent(s): 265f6ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +371 -65
app.py CHANGED
@@ -1,11 +1,11 @@
1
  import gradio as gr
2
- from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer
3
- from transformers import AutoModel
4
  import torch
5
  import numpy as np
6
  import random
 
7
 
8
- # Load models (same as before)
9
  bert_model_name = "alexdseo/RecipeBERT"
10
  bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
11
  bert_model = AutoModel.from_pretrained(bert_model_name)
@@ -15,34 +15,334 @@ MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
15
  t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
16
  t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
17
 
18
- # ... (all your existing functions remain the same) ...
19
- # get_embedding, average_embedding, get_cosine_similarity, etc.
 
 
 
 
20
 
21
- def generate_recipe_interface(required_ingredients_text, available_ingredients_text, max_ingredients, max_retries):
22
- """Gradio interface function"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  try:
24
- # Parse ingredient inputs
25
- required_ingredients = [ing.strip() for ing in required_ingredients_text.split(',') if ing.strip()]
26
- available_ingredients = [ing.strip() for ing in available_ingredients_text.split(',') if ing.strip()]
 
 
 
 
 
 
27
 
28
- # Find optimal ingredient combination
 
 
 
 
 
 
 
 
 
 
29
  optimized_ingredients = find_best_ingredients(
30
  required_ingredients,
31
- available_ingredients,
32
  max_ingredients
33
  )
34
 
35
  # Generate recipe
36
  recipe = generate_recipe_with_t5(optimized_ingredients, max_retries)
37
 
38
- # Format output
39
- ingredients_list = '\n'.join([f"• {ing}" for ing in recipe['ingredients']])
40
- directions_list = '\n'.join([f"{i+1}. {dir}" for i, dir in enumerate(recipe['directions'])])
41
- used_ingredients = ', '.join(optimized_ingredients)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  return (
44
- recipe['title'],
45
- ingredients_list,
46
  directions_list,
47
  used_ingredients
48
  )
@@ -50,59 +350,65 @@ def generate_recipe_interface(required_ingredients_text, available_ingredients_t
50
  except Exception as e:
51
  return f"Error: {str(e)}", "", "", ""
52
 
53
- # Create Gradio interface
54
  with gr.Blocks(title="AI Recipe Generator") as demo:
55
  gr.Markdown("# 🍳 AI Recipe Generator")
56
  gr.Markdown("Generate recipes using AI with intelligent ingredient combination!")
57
 
58
- with gr.Row():
59
- with gr.Column():
60
- required_ing = gr.Textbox(
61
- label="Required Ingredients (comma-separated)",
62
- placeholder="chicken, rice, onion",
63
- lines=2
64
- )
65
- available_ing = gr.Textbox(
66
- label="Available Ingredients (comma-separated)",
67
- placeholder="garlic, tomato, pepper, herbs",
68
- lines=2
69
- )
70
- max_ing = gr.Slider(
71
- minimum=3,
72
- maximum=10,
73
- value=7,
74
- step=1,
75
- label="Maximum Ingredients"
76
- )
77
- max_retries = gr.Slider(
78
- minimum=1,
79
- maximum=10,
80
- value=5,
81
- step=1,
82
- label="Max Retries"
83
- )
84
- generate_btn = gr.Button("Generate Recipe", variant="primary")
85
 
86
- with gr.Column():
87
- title_output = gr.Textbox(label="Recipe Title", interactive=False)
88
- ingredients_output = gr.Textbox(label="Ingredients", lines=8, interactive=False)
89
- directions_output = gr.Textbox(label="Directions", lines=10, interactive=False)
90
- used_ingredients_output = gr.Textbox(label="Used Ingredients", interactive=False)
91
-
92
- generate_btn.click(
93
- fn=generate_recipe_interface,
94
- inputs=[required_ing, available_ing, max_ing, max_retries],
95
- outputs=[title_output, ingredients_output, directions_output, used_ingredients_output]
96
- )
97
 
98
- # Example
99
- gr.Examples(
100
- examples=[
101
- ["chicken, rice", "onion, garlic, tomato, herbs, pepper", 6, 3],
102
- ["pasta", "cheese, mushrooms, cream, spinach, garlic", 5, 3],
103
- ],
104
- inputs=[required_ing, available_ing, max_ing, max_retries]
105
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  if __name__ == "__main__":
108
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
 
3
  import torch
4
  import numpy as np
5
  import random
6
+ import json
7
 
8
+ # Model loading (same as before)
9
  bert_model_name = "alexdseo/RecipeBERT"
10
  bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
11
  bert_model = AutoModel.from_pretrained(bert_model_name)
 
15
  t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
16
  t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
17
 
18
# Special tokens known to the T5 tokenizer; passed to target_postprocessing()
# so they can be stripped from the decoded model output.
special_tokens = t5_tokenizer.all_special_tokens

# Structural markers in the generated text and the plain-text separators
# that replace them during post-processing.
tokens_map = {"<sep>": "--", "<section>": "\n"}
24
 
25
def get_embedding(text):
    """Embed ``text`` with RecipeBERT using attention-masked mean pooling.

    The token embeddings of the last hidden state are averaged, counting only
    positions whose attention mask is set. ``squeeze(0)`` drops the leading
    dimension when it has size one.
    """
    encoded = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    # Inference only: no gradients are needed
    with torch.no_grad():
        hidden_states = bert_model(**encoded).last_hidden_state

    # Broadcast the attention mask over the hidden dimension so masked-out
    # positions contribute nothing to the sum
    mask = encoded['attention_mask'].unsqueeze(-1).expand(hidden_states.size()).float()
    masked_total = torch.sum(hidden_states * mask, 1)

    # Clamp the divisor so an all-zero mask cannot cause a division by zero
    token_counts = torch.clamp(mask.sum(1), min=1e-9)

    pooled = masked_total / token_counts
    return pooled.squeeze(0)
39
+
40
def average_embedding(embedding_list):
    """Return the element-wise mean of the embeddings in ``(name, embedding)`` pairs."""
    vectors = [vector for _name, vector in embedding_list]
    stacked = torch.stack(vectors)
    return stacked.mean(dim=0)
44
+
45
def get_cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two vectors.

    Torch tensors are converted to NumPy arrays first, and both inputs are
    flattened before comparison. Returns 0 when either vector has zero norm.
    """
    # Convert both inputs before flattening either of them
    arrays = [v.detach().numpy() if torch.is_tensor(v) else v for v in (vec1, vec2)]
    a, b = (arr.flatten() for arr in arrays)

    dot_product = np.dot(a, b)
    norm_a, norm_b = np.linalg.norm(a), np.linalg.norm(b)

    # A zero-length vector has no direction; report no similarity
    if norm_a == 0 or norm_b == 0:
        return 0

    return dot_product / (norm_a * norm_b)
65
+
66
def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
    """Rank candidate embeddings by a weighted blend of two similarities.

    For each ``(name, emb)`` candidate in ``embedding_list`` the score is
    ``avg_weight`` times its cosine similarity to ``query_vector`` plus
    ``1 - avg_weight`` times its mean cosine similarity to every embedding in
    ``all_good_embeddings``.

    Returns a list of ``(name, emb, score)`` tuples, highest score first.
    """
    def score(candidate):
        # Similarity to the query (average) vector
        to_query = get_cosine_similarity(query_vector, candidate)

        # Mean similarity to each already-selected ingredient
        pairwise = [get_cosine_similarity(good_emb, candidate)
                    for _, good_emb in all_good_embeddings]
        to_members = sum(pairwise) / len(pairwise)

        return avg_weight * to_query + (1 - avg_weight) * to_members

    scored = [(name, emb, score(emb)) for name, emb in embedding_list]
    return sorted(scored, key=lambda entry: entry[2], reverse=True)
87
+
88
def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
    """
    Finds the best ingredients based on RecipeBERT embeddings.

    Starting from the required ingredients, greedily adds the available
    ingredient with the highest combined score (see ``get_combined_scores``)
    until ``max_ingredients`` is reached or no candidates remain.

    Args:
        required_ingredients: Ingredient names that must be part of the result.
        available_ingredients: Optional ingredient names to choose from.
        max_ingredients: Upper bound on the number of returned ingredients.
        avg_weight: Weight forwarded to ``get_combined_scores``.

    Returns:
        A list of ingredient names: the de-duplicated required ingredients in
        their original order, followed by the selected additions in the order
        they were picked. If no required ingredients are given, one available
        ingredient is chosen at random as the starting point.
    """
    # UI sliders / JSON payloads may deliver an integral float; slicing and
    # range() below need a real int
    max_ingredients = int(max_ingredients)

    # De-duplicate while keeping the caller's order. Iterating a set of
    # strings yields a different order from run to run, which made the
    # returned list (and therefore the generation prompt) unpredictable.
    required_ingredients = list(dict.fromkeys(required_ingredients))
    required_lookup = set(required_ingredients)
    available_ingredients = list(dict.fromkeys(
        i for i in available_ingredients if i not in required_lookup
    ))

    # Special case: If no required ingredients, randomly select one from available ingredients
    if not required_ingredients and available_ingredients:
        random_ingredient = random.choice(available_ingredients)
        required_ingredients = [random_ingredient]
        available_ingredients = [i for i in available_ingredients if i != random_ingredient]

    # If still no ingredients or already at max capacity
    if not required_ingredients or len(required_ingredients) >= max_ingredients:
        return required_ingredients[:max_ingredients]

    # If no additional ingredients available
    if not available_ingredients:
        return required_ingredients

    # Calculate embeddings for all ingredients
    embed_required = [(e, get_embedding(e)) for e in required_ingredients]
    embed_available = [(e, get_embedding(e)) for e in available_ingredients]

    # Number of ingredients to add
    num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients))

    # Copy required ingredients to final list
    final_ingredients = embed_required.copy()

    # Add best ingredients one at a time; the average is recomputed after
    # every pick so each choice accounts for the previous additions
    for _ in range(num_to_add):
        avg = average_embedding(final_ingredients)

        candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight)

        # If no candidates left, break
        if not candidates:
            break

        # Candidates are sorted best-first
        best_name, best_embedding, _ = candidates[0]
        final_ingredients.append((best_name, best_embedding))

        # Remove the chosen ingredient from the remaining pool
        embed_available = [item for item in embed_available if item[0] != best_name]

    # Extract only ingredient names
    return [name for name, _ in final_ingredients]
143
+
144
def skip_special_tokens(text, special_tokens):
    """Return ``text`` with every occurrence of each special token deleted."""
    cleaned = text
    for marker in special_tokens:
        if marker in cleaned:
            cleaned = cleaned.replace(marker, "")
    return cleaned
149
+
150
def target_postprocessing(texts, special_tokens):
    """Clean decoded model output.

    Accepts a single text or a list of texts. Each text has the given special
    tokens removed and then every key of ``tokens_map`` replaced by its mapped
    value. Always returns a list.
    """
    batch = texts if isinstance(texts, list) else [texts]

    def clean(raw):
        # Strip special tokens first, then translate the structural markers
        cleaned = skip_special_tokens(raw, special_tokens)
        for marker, replacement in tokens_map.items():
            cleaned = cleaned.replace(marker, replacement)
        return cleaned

    return [clean(raw) for raw in batch]
165
+
166
def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
    """Validates if the recipe contains approximately the expected ingredients.

    Only the number of ingredients is compared, not their names.

    Args:
        recipe_ingredients: Ingredient strings from the generated recipe;
            empty, whitespace-only and falsy entries are not counted.
        expected_ingredients: The ingredients the recipe was asked to use.
        tolerance: Maximum allowed absolute difference between the two counts.

    Returns:
        True when the counts differ by at most ``tolerance``.
    """
    recipe_count = sum(1 for ing in recipe_ingredients if ing and ing.strip())
    expected_count = len(expected_ingredients)
    # "<=" rather than "==": a tolerance is an upper bound on the difference,
    # so an exact match must still pass when tolerance > 0
    return abs(recipe_count - expected_count) <= tolerance
171
+
172
def generate_recipe_with_t5(ingredients_list, max_retries=5):
    """Generate a recipe with the T5 model, retrying until it validates.

    The first attempt uses the ingredients in the given order; every later
    attempt uses a shuffled copy. An attempt's recipe is returned as soon as
    its ingredient count passes ``validate_recipe_ingredients``; on the last
    attempt the recipe is returned even if validation fails.

    Exceptions raised during an attempt are swallowed and the next attempt
    runs. If the last attempt raises, or no attempt runs at all, a placeholder
    recipe is returned instead.

    Returns:
        A dict with the keys ``"title"``, ``"ingredients"`` and ``"directions"``.
    """
    original_ingredients = ingredients_list.copy()

    def fallback_recipe():
        # Placeholder result used when generation could not produce a recipe
        return {
            "title": f"Recipe with {original_ingredients[0] if original_ingredients else 'ingredients'}",
            "ingredients": original_ingredients,
            "directions": ["Error generating recipe instructions"]
        }

    # Sampling configuration shared by every generate() call
    sampling_options = {
        "max_length": 512,
        "min_length": 64,
        "do_sample": True,
        "top_k": 60,
        "top_p": 0.95
    }

    # Section labels that hold "--"-separated lists, and the recipe key each fills
    list_sections = (("ingredients:", "ingredients"), ("directions:", "directions"))

    for attempt in range(max_retries):
        try:
            # Retries re-order the ingredients to get a different prompt
            if attempt == 0:
                current_ingredients = ingredients_list
            else:
                current_ingredients = original_ingredients.copy()
                random.shuffle(current_ingredients)

            # Tokenize the prompt: fixed prefix plus comma-separated ingredients
            model_inputs = t5_tokenizer(
                "items: " + ", ".join(current_ingredients),
                max_length=256,
                padding="max_length",
                truncation=True,
                return_tensors="jax"
            )

            generation = t5_model.generate(
                input_ids=model_inputs.input_ids,
                attention_mask=model_inputs.attention_mask,
                **sampling_options
            )

            # Decode with special tokens kept; post-processing removes or maps them
            decoded = t5_tokenizer.batch_decode(generation.sequences, skip_special_tokens=False)
            generated_text = target_postprocessing(decoded, special_tokens)[0]

            # Parse the labelled sections, one per line
            recipe = {}
            for raw_section in generated_text.split("\n"):
                section = raw_section.strip()
                if section.startswith("title:"):
                    recipe["title"] = section.replace("title:", "").strip().capitalize()
                    continue
                for label, key in list_sections:
                    if section.startswith(label):
                        body = section.replace(label, "").strip()
                        recipe[key] = [part.strip().capitalize() for part in body.split("--") if part.strip()]
                        break

            # Fill in any section the model did not produce
            recipe.setdefault("title", f"Recipe with {', '.join(current_ingredients[:3])}")
            recipe.setdefault("ingredients", current_ingredients)
            recipe.setdefault("directions", ["No directions generated"])

            # Accept a validated recipe, or whatever we have on the last attempt
            if validate_recipe_ingredients(recipe["ingredients"], original_ingredients) or attempt == max_retries - 1:
                return recipe

        except Exception:
            if attempt == max_retries - 1:
                return fallback_recipe()

    # Reached only when the loop body never returned (e.g. no attempts ran)
    return fallback_recipe()
266
+
267
def flutter_api_generate_recipe(ingredients_data):
    """
    JSON-in / JSON-out entry point for recipe generation.

    ``ingredients_data`` is either a JSON string or an already-parsed dict with
    the optional keys ``required_ingredients``, ``available_ingredients``,
    ``max_ingredients`` (default 7) and ``max_retries`` (default 5). The legacy
    key ``ingredients`` is used when no required ingredients are given.

    Returns a JSON string containing ``title``, ``ingredients``, ``directions``
    and ``used_ingredients``, or a single ``error`` key if no ingredients were
    supplied or anything raised.
    """
    try:
        # Accept both a raw JSON string and a parsed dict
        data = json.loads(ingredients_data) if isinstance(ingredients_data, str) else ingredients_data

        required_ingredients = data.get('required_ingredients', [])
        available_ingredients = data.get('available_ingredients', [])

        # Backward compatibility: fall back to the old 'ingredients' key
        legacy_ingredients = data.get('ingredients')
        if legacy_ingredients and not required_ingredients:
            required_ingredients = data.get('ingredients', [])

        max_ingredients = data.get('max_ingredients', 7)
        max_retries = data.get('max_retries', 5)

        if not (required_ingredients or available_ingredients):
            return json.dumps({"error": "No ingredients provided"})

        # Pick the ingredient combination, then generate the recipe from it
        optimized_ingredients = find_best_ingredients(
            required_ingredients,
            available_ingredients,
            max_ingredients
        )
        recipe = generate_recipe_with_t5(optimized_ingredients, max_retries)

        return json.dumps({
            'title': recipe['title'],
            'ingredients': recipe['ingredients'],
            'directions': recipe['directions'],
            'used_ingredients': optimized_ingredients
        })

    except Exception as e:
        return json.dumps({"error": f"Error in recipe generation: {str(e)}"})
315
+
316
def gradio_ui_generate_recipe(required_ingredients_text, available_ingredients_text, max_ingredients, max_retries):
    """Gradio callback: turn the form inputs into a recipe for display.

    Both ingredient fields are comma-separated text. The request is routed
    through ``flutter_api_generate_recipe`` so the UI and the API share one
    code path.

    Returns a 4-tuple ``(title, ingredients, directions, used_ingredients)`` of
    strings; on failure the first element carries the error message and the
    other three are empty.
    """
    def split_csv(text):
        # Comma-separated text -> list of non-empty, trimmed entries
        return [ing.strip() for ing in text.split(',') if ing.strip()]

    try:
        payload = {
            'required_ingredients': split_csv(required_ingredients_text),
            'available_ingredients': split_csv(available_ingredients_text),
            'max_ingredients': max_ingredients,
            'max_retries': max_retries
        }

        # Same function the API tab calls; it answers with a JSON string
        result = json.loads(flutter_api_generate_recipe(payload))

        if 'error' in result:
            return result['error'], "", "", ""

        # Bulleted ingredients, numbered directions, comma-joined selection
        ingredients_list = '\n'.join(f"• {ing}" for ing in result['ingredients'])
        directions_list = '\n'.join(
            f"{number}. {step}" for number, step in enumerate(result['directions'], start=1)
        )
        used_ingredients = ', '.join(result['used_ingredients'])

        return result['title'], ingredients_list, directions_list, used_ingredients

    except Exception as e:
        return f"Error: {str(e)}", "", "", ""
352
 
353
# Build the Gradio app: a form-based tab and a raw-JSON tab that calls the API function directly
with gr.Blocks(title="AI Recipe Generator") as demo:
    gr.Markdown("# 🍳 AI Recipe Generator")
    gr.Markdown("Generate recipes using AI with intelligent ingredient combination!")

    with gr.Tab("Web Interface"):
        with gr.Row():
            # Left column: inputs and the trigger button
            with gr.Column():
                required_ing = gr.Textbox(
                    label="Required Ingredients (comma-separated)",
                    placeholder="chicken, rice, onion",
                    lines=2,
                )
                available_ing = gr.Textbox(
                    label="Available Ingredients (comma-separated)",
                    placeholder="garlic, tomato, pepper, herbs",
                    lines=2,
                )
                max_ing = gr.Slider(minimum=3, maximum=10, value=7, step=1, label="Maximum Ingredients")
                max_retries = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Max Retries")
                generate_btn = gr.Button("Generate Recipe", variant="primary")

            # Right column: read-only result fields
            with gr.Column():
                title_output = gr.Textbox(label="Recipe Title", interactive=False)
                ingredients_output = gr.Textbox(label="Ingredients", lines=8, interactive=False)
                directions_output = gr.Textbox(label="Directions", lines=10, interactive=False)
                used_ingredients_output = gr.Textbox(label="Used Ingredients", interactive=False)

        generate_btn.click(
            fn=gradio_ui_generate_recipe,
            inputs=[required_ing, available_ing, max_ing, max_retries],
            outputs=[title_output, ingredients_output, directions_output, used_ingredients_output],
        )

    with gr.Tab("API Testing"):
        gr.Markdown("### Test the Flutter API")
        gr.Markdown("This tab uses the same function that Flutter apps will call via API")

        api_input = gr.Textbox(
            label="JSON Input (Flutter API Format)",
            placeholder='{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic"], "max_ingredients": 6}',
            lines=4,
        )
        api_output = gr.Textbox(label="JSON Output", lines=15, interactive=False)
        api_test_btn = gr.Button("Test API", variant="secondary")

        api_test_btn.click(
            fn=flutter_api_generate_recipe,
            inputs=[api_input],
            outputs=[api_output],
        )

        # Sample payloads; the second one exercises the legacy "ingredients" key
        gr.Examples(
            examples=[
                ['{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic", "tomato"], "max_ingredients": 6}'],
                ['{"ingredients": ["pasta"], "available_ingredients": ["cheese", "mushrooms", "cream"], "max_ingredients": 5}'],
            ],
            inputs=[api_input],
        )

if __name__ == "__main__":
    demo.launch()