TimInf commited on
Commit
1787e4b
·
verified ·
1 Parent(s): 585d2b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -387
app.py CHANGED
@@ -1,387 +1,108 @@
1
- from flask import Flask, request, jsonify
2
- from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer
3
- from transformers import AutoModel
4
- import torch
5
- import numpy as np
6
- import random
7
- from flask_cors import CORS
8
-
9
- app = Flask(__name__)
10
- CORS(app)
11
-
12
- # Load RecipeBERT model (for semantic ingredient combination)
13
- bert_model_name = "alexdseo/RecipeBERT"
14
- bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
15
- bert_model = AutoModel.from_pretrained(bert_model_name)
16
- bert_model.eval()
17
-
18
- # Load T5 recipe generation model
19
- MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
20
- t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
21
- t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
22
-
23
- # Token mapping for T5 model output processing
24
- special_tokens = t5_tokenizer.all_special_tokens
25
- tokens_map = {
26
- "<sep>": "--",
27
- "<section>": "\n"
28
- }
29
-
30
-
31
- def get_embedding(text):
32
- """Computes embedding for a text with Mean Pooling over all tokens"""
33
- inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
34
- with torch.no_grad():
35
- outputs = bert_model(**inputs)
36
-
37
- # Mean Pooling - take average of all token embeddings
38
- attention_mask = inputs['attention_mask']
39
- token_embeddings = outputs.last_hidden_state
40
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
41
- sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
42
- sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
43
-
44
- return (sum_embeddings / sum_mask).squeeze(0)
45
-
46
-
47
- def average_embedding(embedding_list):
48
- """Computes the average of a list of embeddings"""
49
- tensors = torch.stack([emb for _, emb in embedding_list])
50
- return tensors.mean(dim=0)
51
-
52
-
53
- def get_cosine_similarity(vec1, vec2):
54
- """Computes the cosine similarity between two vectors"""
55
- if torch.is_tensor(vec1):
56
- vec1 = vec1.detach().numpy()
57
- if torch.is_tensor(vec2):
58
- vec2 = vec2.detach().numpy()
59
-
60
- # Make sure vectors have the right shape (flatten if necessary)
61
- vec1 = vec1.flatten()
62
- vec2 = vec2.flatten()
63
-
64
- dot_product = np.dot(vec1, vec2)
65
- norm_a = np.linalg.norm(vec1)
66
- norm_b = np.linalg.norm(vec2)
67
-
68
- # Avoid division by zero
69
- if norm_a == 0 or norm_b == 0:
70
- return 0
71
-
72
- return dot_product / (norm_a * norm_b)
73
-
74
-
75
- def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
76
- """Computes combined score considering both similarity to average and individual ingredients"""
77
- results = []
78
-
79
- for name, emb in embedding_list:
80
- # Similarity to average vector
81
- avg_similarity = get_cosine_similarity(query_vector, emb)
82
-
83
- # Average similarity to individual ingredients
84
- individual_similarities = [get_cosine_similarity(good_emb, emb)
85
- for _, good_emb in all_good_embeddings]
86
- avg_individual_similarity = sum(individual_similarities) / len(individual_similarities)
87
-
88
- # Combined score (weighted average)
89
- combined_score = avg_weight * avg_similarity + (1 - avg_weight) * avg_individual_similarity
90
-
91
- results.append((name, emb, combined_score))
92
-
93
- # Sort by combined score (descending)
94
- results.sort(key=lambda x: x[2], reverse=True)
95
- return results
96
-
97
-
98
- def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
99
- """
100
- Finds the best ingredients based on RecipeBERT embeddings.
101
-
102
- Args:
103
- required_ingredients (list): Required ingredients that must be used
104
- available_ingredients (list): Available ingredients to choose from
105
- max_ingredients (int): Maximum number of ingredients for the recipe
106
- avg_weight (float): Weight for average vector
107
-
108
- Returns:
109
- list: The optimal combination of ingredients
110
- """
111
- # Ensure no duplicates in lists
112
- required_ingredients = list(set(required_ingredients))
113
- available_ingredients = list(set([i for i in available_ingredients if i not in required_ingredients]))
114
-
115
- # Special case: If no required ingredients, randomly select one from available ingredients
116
- if not required_ingredients and available_ingredients:
117
- # Randomly select 1 ingredient as starting point
118
- random_ingredient = random.choice(available_ingredients)
119
- required_ingredients = [random_ingredient]
120
- available_ingredients = [i for i in available_ingredients if i != random_ingredient]
121
- print(f"No required ingredients provided. Randomly selected: {random_ingredient}")
122
-
123
- # If still no ingredients or already at max capacity
124
- if not required_ingredients or len(required_ingredients) >= max_ingredients:
125
- return required_ingredients[:max_ingredients]
126
-
127
- # If no additional ingredients available
128
- if not available_ingredients:
129
- return required_ingredients
130
-
131
- # Calculate embeddings for all ingredients
132
- embed_required = [(e, get_embedding(e)) for e in required_ingredients]
133
- embed_available = [(e, get_embedding(e)) for e in available_ingredients]
134
-
135
- # Number of ingredients to add
136
- num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients))
137
-
138
- # Copy required ingredients to final list
139
- final_ingredients = embed_required.copy()
140
-
141
- # Add best ingredients
142
- for _ in range(num_to_add):
143
- # Calculate average vector of current combination
144
- avg = average_embedding(final_ingredients)
145
-
146
- # Calculate combined scores for all candidates
147
- candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight)
148
-
149
- # If no candidates left, break
150
- if not candidates:
151
- break
152
-
153
- # Choose best ingredient
154
- best_name, best_embedding, _ = candidates[0]
155
-
156
- # Add best ingredient to final list
157
- final_ingredients.append((best_name, best_embedding))
158
-
159
- # Remove ingredient from available ingredients
160
- embed_available = [item for item in embed_available if item[0] != best_name]
161
-
162
- # Extract only ingredient names
163
- return [name for name, _ in final_ingredients]
164
-
165
-
166
- def skip_special_tokens(text, special_tokens):
167
- """Removes special tokens from text"""
168
- for token in special_tokens:
169
- text = text.replace(token, "")
170
- return text
171
-
172
-
173
- def target_postprocessing(texts, special_tokens):
174
- """Post-processes generated text"""
175
- if not isinstance(texts, list):
176
- texts = [texts]
177
-
178
- new_texts = []
179
- for text in texts:
180
- text = skip_special_tokens(text, special_tokens)
181
-
182
- for k, v in tokens_map.items():
183
- text = text.replace(k, v)
184
-
185
- new_texts.append(text)
186
-
187
- return new_texts
188
-
189
-
190
- def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
191
- """
192
- Validates if the recipe contains approximately the expected ingredients.
193
-
194
- Args:
195
- recipe_ingredients (list): Ingredients from generated recipe
196
- expected_ingredients (list): Expected ingredients
197
- tolerance (int): Allowed difference in ingredient count
198
-
199
- Returns:
200
- bool: True if recipe is valid, False otherwise
201
- """
202
- # Count non-empty ingredients
203
- recipe_count = len([ing for ing in recipe_ingredients if ing and ing.strip()])
204
- expected_count = len(expected_ingredients)
205
-
206
- # Check if ingredient count is within tolerance
207
- return abs(recipe_count - expected_count) == tolerance
208
-
209
-
210
- def generate_recipe_with_t5(ingredients_list, max_retries=5):
211
- """
212
- Generates a recipe using the T5 recipe generation model with validation.
213
-
214
- Args:
215
- ingredients_list (list): List of ingredients
216
- max_retries (int): Maximum number of retry attempts
217
-
218
- Returns:
219
- dict: A dictionary with title, ingredients, and directions
220
- """
221
- original_ingredients = ingredients_list.copy()
222
-
223
- for attempt in range(max_retries):
224
- try:
225
- # For retries after the first attempt, shuffle the ingredients
226
- if attempt > 0:
227
- current_ingredients = original_ingredients.copy()
228
- random.shuffle(current_ingredients)
229
- print(f"Retry {attempt}: Shuffling ingredients order")
230
- else:
231
- current_ingredients = ingredients_list
232
-
233
- # Format ingredients as a comma-separated string
234
- ingredients_string = ", ".join(current_ingredients)
235
- prefix = "items: "
236
-
237
- # Generation settings
238
- generation_kwargs = {
239
- "max_length": 512,
240
- "min_length": 64,
241
- "do_sample": True,
242
- "top_k": 60,
243
- "top_p": 0.95
244
- }
245
- print(f"Attempt {attempt + 1}: {prefix + ingredients_string}")
246
-
247
- # Tokenize input
248
- inputs = t5_tokenizer(
249
- prefix + ingredients_string,
250
- max_length=256,
251
- padding="max_length",
252
- truncation=True,
253
- return_tensors="jax"
254
- )
255
-
256
- # Generate text
257
- output_ids = t5_model.generate(
258
- input_ids=inputs.input_ids,
259
- attention_mask=inputs.attention_mask,
260
- **generation_kwargs
261
- )
262
-
263
- # Decode and post-process
264
- generated = output_ids.sequences
265
- generated_text = target_postprocessing(
266
- t5_tokenizer.batch_decode(generated, skip_special_tokens=False),
267
- special_tokens
268
- )[0]
269
-
270
- # Parse sections
271
- recipe = {}
272
- sections = generated_text.split("\n")
273
- for section in sections:
274
- section = section.strip()
275
- if section.startswith("title:"):
276
- recipe["title"] = section.replace("title:", "").strip().capitalize()
277
- elif section.startswith("ingredients:"):
278
- ingredients_text = section.replace("ingredients:", "").strip()
279
- recipe["ingredients"] = [item.strip().capitalize() for item in ingredients_text.split("--") if
280
- item.strip()]
281
- elif section.startswith("directions:"):
282
- directions_text = section.replace("directions:", "").strip()
283
- recipe["directions"] = [step.strip().capitalize() for step in directions_text.split("--") if
284
- step.strip()]
285
-
286
- # If title is missing, create one
287
- if "title" not in recipe:
288
- recipe["title"] = f"Recipe with {', '.join(current_ingredients[:3])}"
289
-
290
- # Ensure all sections exist
291
- if "ingredients" not in recipe:
292
- recipe["ingredients"] = current_ingredients
293
- if "directions" not in recipe:
294
- recipe["directions"] = ["No directions generated"]
295
-
296
- # Validate the recipe
297
- if validate_recipe_ingredients(recipe["ingredients"], original_ingredients):
298
- print(f"Success on attempt {attempt + 1}: Recipe has correct number of ingredients")
299
- return recipe
300
- else:
301
- print(
302
- f"Attempt {attempt + 1} failed: Expected {len(original_ingredients)} ingredients, got {len(recipe['ingredients'])}")
303
- if attempt == max_retries - 1:
304
- print("Max retries reached, returning last generated recipe")
305
- return recipe
306
-
307
- except Exception as e:
308
- print(f"Error in recipe generation attempt {attempt + 1}: {str(e)}")
309
- if attempt == max_retries - 1:
310
- return {
311
- "title": f"Recipe with {original_ingredients[0] if original_ingredients else 'ingredients'}",
312
- "ingredients": original_ingredients,
313
- "directions": ["Error generating recipe instructions"]
314
- }
315
-
316
- # Fallback (should not be reached)
317
- return {
318
- "title": f"Recipe with {original_ingredients[0] if original_ingredients else 'ingredients'}",
319
- "ingredients": original_ingredients,
320
- "directions": ["Error generating recipe instructions"]
321
- }
322
-
323
-
324
- @app.route('/generate_recipe', methods=['POST'])
325
- def handle_recipe_request():
326
- """
327
- Processes a recipe generation request with a given list of ingredients.
328
- Uses the intelligent ingredient combination feature.
329
- """
330
- if not request.is_json:
331
- return jsonify({"error": "Request must be JSON"}), 415
332
-
333
- data = request.get_json()
334
-
335
- # Extract required and available ingredients from request
336
- required_ingredients = data.get('required_ingredients', [])
337
- available_ingredients = data.get('available_ingredients', [])
338
-
339
- # For backward compatibility: If only 'ingredients' is specified, treat as required ingredients
340
- if data.get('ingredients') and not required_ingredients:
341
- required_ingredients = data.get('ingredients', [])
342
-
343
- # Maximum number of ingredients (for better recipes)
344
- max_ingredients = data.get('max_ingredients', 7)
345
-
346
- # Maximum retries for recipe generation
347
- max_retries = data.get('max_retries', 5)
348
-
349
- # If no ingredients specified
350
- if not required_ingredients and not available_ingredients:
351
- return jsonify({"error": "No ingredients provided"}), 400
352
-
353
- try:
354
- # Always find best ingredient combination with RecipeBERT
355
- optimized_ingredients = find_best_ingredients(
356
- required_ingredients,
357
- available_ingredients,
358
- max_ingredients
359
- )
360
-
361
- # Generate recipe with optimized ingredients using T5 model with validation
362
- recipe = generate_recipe_with_t5(optimized_ingredients, max_retries)
363
-
364
- # Format for Flutter app consumption - structured format
365
- return jsonify({
366
- 'title': recipe['title'],
367
- 'ingredients': recipe['ingredients'],
368
- 'directions': recipe['directions'],
369
- 'used_ingredients': optimized_ingredients
370
- })
371
-
372
- except Exception as e:
373
- return jsonify({"error": f"Error in recipe generation: {str(e)}"}), 500
374
-
375
-
376
- @app.route('/generate_recipe_smart', methods=['POST'])
377
- def handle_smart_recipe_request():
378
- """
379
- Processes an intelligent recipe generation request.
380
- This endpoint remains for backward compatibility.
381
- """
382
- # Delegate to handle_recipe_request
383
- return handle_recipe_request()
384
-
385
-
386
- if __name__ == '__main__':
387
- app.run(host='0.0.0.0', port=8000, debug=True)
 
1
+ import gradio as gr
2
+ from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer
3
+ from transformers import AutoModel
4
+ import torch
5
+ import numpy as np
6
+ import random
7
+
8
+ # Load models (same as before)
9
+ bert_model_name = "alexdseo/RecipeBERT"
10
+ bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
11
+ bert_model = AutoModel.from_pretrained(bert_model_name)
12
+ bert_model.eval()
13
+
14
+ MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
15
+ t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
16
+ t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
17
+
18
+ # ... (all your existing functions remain the same) ...
19
+ # get_embedding, average_embedding, get_cosine_similarity, etc.
20
+
21
+ def generate_recipe_interface(required_ingredients_text, available_ingredients_text, max_ingredients, max_retries):
22
+ """Gradio interface function"""
23
+ try:
24
+ # Parse ingredient inputs
25
+ required_ingredients = [ing.strip() for ing in required_ingredients_text.split(',') if ing.strip()]
26
+ available_ingredients = [ing.strip() for ing in available_ingredients_text.split(',') if ing.strip()]
27
+
28
+ # Find optimal ingredient combination
29
+ optimized_ingredients = find_best_ingredients(
30
+ required_ingredients,
31
+ available_ingredients,
32
+ max_ingredients
33
+ )
34
+
35
+ # Generate recipe
36
+ recipe = generate_recipe_with_t5(optimized_ingredients, max_retries)
37
+
38
+ # Format output
39
+ ingredients_list = '\n'.join([f"• {ing}" for ing in recipe['ingredients']])
40
+ directions_list = '\n'.join([f"{i+1}. {dir}" for i, dir in enumerate(recipe['directions'])])
41
+ used_ingredients = ', '.join(optimized_ingredients)
42
+
43
+ return (
44
+ recipe['title'],
45
+ ingredients_list,
46
+ directions_list,
47
+ used_ingredients
48
+ )
49
+
50
+ except Exception as e:
51
+ return f"Error: {str(e)}", "", "", ""
52
+
53
+ # Create Gradio interface
54
+ with gr.Blocks(title="AI Recipe Generator") as demo:
55
+ gr.Markdown("# 🍳 AI Recipe Generator")
56
+ gr.Markdown("Generate recipes using AI with intelligent ingredient combination!")
57
+
58
+ with gr.Row():
59
+ with gr.Column():
60
+ required_ing = gr.Textbox(
61
+ label="Required Ingredients (comma-separated)",
62
+ placeholder="chicken, rice, onion",
63
+ lines=2
64
+ )
65
+ available_ing = gr.Textbox(
66
+ label="Available Ingredients (comma-separated)",
67
+ placeholder="garlic, tomato, pepper, herbs",
68
+ lines=2
69
+ )
70
+ max_ing = gr.Slider(
71
+ minimum=3,
72
+ maximum=10,
73
+ value=7,
74
+ step=1,
75
+ label="Maximum Ingredients"
76
+ )
77
+ max_retries = gr.Slider(
78
+ minimum=1,
79
+ maximum=10,
80
+ value=5,
81
+ step=1,
82
+ label="Max Retries"
83
+ )
84
+ generate_btn = gr.Button("Generate Recipe", variant="primary")
85
+
86
+ with gr.Column():
87
+ title_output = gr.Textbox(label="Recipe Title", interactive=False)
88
+ ingredients_output = gr.Textbox(label="Ingredients", lines=8, interactive=False)
89
+ directions_output = gr.Textbox(label="Directions", lines=10, interactive=False)
90
+ used_ingredients_output = gr.Textbox(label="Used Ingredients", interactive=False)
91
+
92
+ generate_btn.click(
93
+ fn=generate_recipe_interface,
94
+ inputs=[required_ing, available_ing, max_ing, max_retries],
95
+ outputs=[title_output, ingredients_output, directions_output, used_ingredients_output]
96
+ )
97
+
98
+ # Example
99
+ gr.Examples(
100
+ examples=[
101
+ ["chicken, rice", "onion, garlic, tomato, herbs, pepper", 6, 3],
102
+ ["pasta", "cheese, mushrooms, cream, spinach, garlic", 5, 3],
103
+ ],
104
+ inputs=[required_ing, available_ing, max_ing, max_retries]
105
+ )
106
+
107
+ if __name__ == "__main__":
108
+ demo.launch()