TimInf committed on
Commit
5751919
·
verified ·
1 Parent(s): 0fb7e49

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +387 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer
3
+ from transformers import AutoModel
4
+ import torch
5
+ import numpy as np
6
+ import random
7
+ from flask_cors import CORS
8
+
9
# --- Web application --------------------------------------------------------
# One Flask app for the whole service, wrapped with flask-cors using its
# default options.
app = Flask(__name__)
CORS(app)

# --- RecipeBERT --------------------------------------------------------------
# Encoder used to embed ingredient names for the semantic ingredient selection.
bert_model_name = "alexdseo/RecipeBERT"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = AutoModel.from_pretrained(bert_model_name)
bert_model.eval()  # inference only

# --- T5 recipe generator -----------------------------------------------------
# Flax seq2seq model that turns an ingredient list into recipe text.
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)

# --- Output post-processing tables -------------------------------------------
# Tokens stripped from the decoded T5 output, and the markers rewritten into
# the separators the recipe parser splits on.
special_tokens = t5_tokenizer.all_special_tokens
tokens_map = {
    "<sep>": "--",
    "<section>": "\n",
}
29
+
30
+
31
def get_embedding(text):
    """Embed ``text`` with RecipeBERT using attention-masked mean pooling.

    The text is tokenized with ``bert_tokenizer`` (truncation and padding
    enabled), run through ``bert_model`` without gradient tracking, and the
    last hidden states are averaged over the positions whose attention mask
    is set.

    Args:
        text: Input passed straight to the tokenizer.

    Returns:
        torch.Tensor: The pooled embedding with dimension 0 squeezed (this
        removes the batch dimension when a single text is embedded).
    """
    encoded = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        hidden_states = bert_model(**encoded).last_hidden_state

    # Broadcast the attention mask over the hidden dimension so padded
    # positions contribute nothing to the sum.
    mask = encoded['attention_mask'].unsqueeze(-1).expand(hidden_states.size()).float()
    masked_total = (hidden_states * mask).sum(1)

    # Clamp the token count so an all-zero mask cannot divide by zero.
    token_count = mask.sum(1).clamp(min=1e-9)

    pooled = masked_total / token_count
    return pooled.squeeze(0)
45
+
46
+
47
def average_embedding(embedding_list):
    """Return the element-wise mean of the embeddings in ``embedding_list``.

    Args:
        embedding_list: Sequence of ``(name, tensor)`` pairs; the names are
            ignored and the tensors must be stackable (same shape).

    Returns:
        torch.Tensor: Mean over the stacked tensors along dimension 0.
    """
    vectors = []
    for _name, vector in embedding_list:
        vectors.append(vector)
    stacked = torch.stack(vectors)
    return torch.mean(stacked, dim=0)
51
+
52
+
53
def get_cosine_similarity(vec1, vec2):
    """Compute the cosine similarity between two vectors.

    Each argument may be a torch tensor, a NumPy array, or a plain sequence
    of numbers. Inputs are flattened first, so a ``(1, d)`` value is treated
    like a ``(d,)`` vector.

    Args:
        vec1: First vector.
        vec2: Second vector (same number of elements as ``vec1``).

    Returns:
        float: Cosine similarity, or ``0.0`` when either vector has zero
        norm. The result is always a built-in ``float``.
    """
    def _as_flat_array(vec):
        if torch.is_tensor(vec):
            # .cpu() is a no-op for CPU tensors and makes tensors living on
            # another device convertible to NumPy.
            vec = vec.detach().cpu().numpy()
        return np.asarray(vec).flatten()

    a = _as_flat_array(vec1)
    b = _as_flat_array(vec2)

    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)

    # Avoid division by zero: a zero vector has no direction.
    if norm_a == 0 or norm_b == 0:
        return 0.0

    return float(np.dot(a, b) / (norm_a * norm_b))
73
+
74
+
75
def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
    """Rank candidate embeddings by a blended similarity score.

    For each candidate the score is
    ``avg_weight * sim(query_vector, candidate)
    + (1 - avg_weight) * mean(sim(chosen, candidate) for chosen in all_good_embeddings)``,
    where ``sim`` is ``get_cosine_similarity``.

    Args:
        query_vector: Vector each candidate is compared against (the caller
            passes the average of the already chosen embeddings).
        embedding_list: Candidate ``(name, embedding)`` pairs.
        all_good_embeddings: Already chosen ``(name, embedding)`` pairs; must
            be non-empty when there is at least one candidate.
        avg_weight: Weight of the similarity to ``query_vector``.

    Returns:
        list: ``(name, embedding, score)`` tuples, highest score first.
    """
    scored = []
    for candidate_name, candidate_vec in embedding_list:
        # How close the candidate is to the query (average) vector.
        to_query = get_cosine_similarity(query_vector, candidate_vec)

        # How close it is, on average, to each chosen ingredient on its own.
        per_item = [
            get_cosine_similarity(chosen_vec, candidate_vec)
            for _, chosen_vec in all_good_embeddings
        ]
        to_items = sum(per_item) / len(per_item)

        blended = avg_weight * to_query + (1 - avg_weight) * to_items
        scored.append((candidate_name, candidate_vec, blended))

    # Best-scoring candidate first; ties keep their input order.
    return sorted(scored, key=lambda entry: entry[2], reverse=True)
96
+
97
+
98
def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
    """
    Finds the best ingredients based on RecipeBERT embeddings.

    Required ingredients are always kept (up to ``max_ingredients``). The
    remaining slots are filled greedily: at each step the available
    ingredient with the highest combined similarity to the current selection
    (see ``get_combined_scores``) is added.

    Args:
        required_ingredients (list): Required ingredients that must be used
        available_ingredients (list): Available ingredients to choose from
        max_ingredients (int): Maximum number of ingredients for the recipe;
            values below zero are treated as zero
        avg_weight (float): Weight for average vector

    Returns:
        list: The selected ingredient names, required ingredients first in
        the order the caller supplied them.
    """
    # De-duplicate while preserving the caller's order. A plain set() would
    # make the order (and therefore which items survive truncation below)
    # vary between runs.
    required_ingredients = list(dict.fromkeys(required_ingredients))
    required_lookup = set(required_ingredients)
    available_ingredients = [
        item for item in dict.fromkeys(available_ingredients)
        if item not in required_lookup
    ]

    # A negative limit would otherwise slice from the end of the list.
    max_ingredients = max(max_ingredients, 0)

    # Special case: with no required ingredients, seed the selection with one
    # randomly chosen available ingredient.
    if not required_ingredients and available_ingredients:
        random_ingredient = random.choice(available_ingredients)
        required_ingredients = [random_ingredient]
        available_ingredients = [i for i in available_ingredients if i != random_ingredient]
        print(f"No required ingredients provided. Randomly selected: {random_ingredient}")

    # Nothing to select from, or the required ingredients already fill the recipe.
    if not required_ingredients or len(required_ingredients) >= max_ingredients:
        return required_ingredients[:max_ingredients]

    # No additional ingredients available.
    if not available_ingredients:
        return required_ingredients

    # Embed every ingredient once up front; the greedy loop only reuses them.
    final_ingredients = [(name, get_embedding(name)) for name in required_ingredients]
    embed_available = [(name, get_embedding(name)) for name in available_ingredients]

    num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients))

    for _ in range(num_to_add):
        # Average vector of the current combination.
        avg = average_embedding(final_ingredients)

        candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight)
        if not candidates:
            break

        best_name, best_embedding, _ = candidates[0]
        final_ingredients.append((best_name, best_embedding))

        # The chosen ingredient is no longer a candidate.
        embed_available = [item for item in embed_available if item[0] != best_name]

    return [name for name, _ in final_ingredients]
164
+
165
+
166
def skip_special_tokens(text, special_tokens):
    """Return ``text`` with every occurrence of each special token removed.

    Tokens are removed one after another, in the order given.
    """
    cleaned = text
    for marker in special_tokens:
        cleaned = cleaned.replace(marker, "")
    return cleaned
171
+
172
+
173
def target_postprocessing(texts, special_tokens):
    """Clean up decoded model output.

    Each text has the given special tokens stripped, then every key of the
    module-level ``tokens_map`` replaced by its mapped value.

    Args:
        texts: A list of strings, or a single value that is wrapped in a list.
        special_tokens: Tokens to delete from each text.

    Returns:
        list: The processed texts, in input order.
    """
    batch = texts if isinstance(texts, list) else [texts]

    processed = []
    for raw in batch:
        cleaned = skip_special_tokens(raw, special_tokens)
        for marker, replacement in tokens_map.items():
            cleaned = cleaned.replace(marker, replacement)
        processed.append(cleaned)

    return processed
188
+
189
+
190
def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
    """
    Validates if the recipe contains approximately the expected ingredients.

    Only the number of ingredients is compared, not their names. Empty,
    whitespace-only and falsy entries in ``recipe_ingredients`` are ignored.

    Args:
        recipe_ingredients (list): Ingredients from generated recipe
        expected_ingredients (list): Expected ingredients
        tolerance (int): Allowed difference in ingredient count

    Returns:
        bool: True if the counts differ by at most ``tolerance``, False otherwise
    """
    recipe_count = sum(1 for ing in recipe_ingredients if ing and ing.strip())
    expected_count = len(expected_ingredients)

    # "Within tolerance" means at most `tolerance` apart. An equality test
    # would reject exact matches as soon as tolerance > 0.
    return abs(recipe_count - expected_count) <= tolerance
208
+
209
+
210
def _fallback_recipe(ingredients):
    """Build the placeholder recipe returned when no recipe could be generated."""
    return {
        "title": f"Recipe with {ingredients[0] if ingredients else 'ingredients'}",
        "ingredients": ingredients,
        "directions": ["Error generating recipe instructions"]
    }


def _split_recipe_items(body):
    """Split a section body on "--" into capitalized, non-empty items."""
    return [item.strip().capitalize() for item in body.split("--") if item.strip()]


def _parse_generated_recipe(generated_text):
    """Extract the title/ingredients/directions sections from post-processed model text."""
    recipe = {}
    for line in generated_text.split("\n"):
        line = line.strip()
        # Strip only the leading label. str.replace would also delete the
        # same word if it appeared again inside the section text.
        if line.startswith("title:"):
            recipe["title"] = line[len("title:"):].strip().capitalize()
        elif line.startswith("ingredients:"):
            recipe["ingredients"] = _split_recipe_items(line[len("ingredients:"):])
        elif line.startswith("directions:"):
            recipe["directions"] = _split_recipe_items(line[len("directions:"):])
    return recipe


def generate_recipe_with_t5(ingredients_list, max_retries=5):
    """
    Generates a recipe using the T5 recipe generation model with validation.

    Each attempt samples a recipe for the prompt ``"items: <a>, <b>, ..."``.
    Retries after the first shuffle the ingredient order. An attempt succeeds
    when the generated ingredient count matches the requested one (see
    ``validate_recipe_ingredients``). If the last attempt generates a recipe
    that fails validation, that recipe is returned anyway. If the last
    attempt raises, or ``max_retries`` is zero or negative, a placeholder
    recipe is returned.

    Args:
        ingredients_list (list): List of ingredients
        max_retries (int): Maximum number of generation attempts

    Returns:
        dict: A dictionary with title, ingredients, and directions
    """
    # Work on a copy so the caller's list is never shuffled or aliased.
    original_ingredients = list(ingredients_list)

    # Sampling settings are the same for every attempt.
    generation_kwargs = {
        "max_length": 512,
        "min_length": 64,
        "do_sample": True,
        "top_k": 60,
        "top_p": 0.95
    }
    prefix = "items: "

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            current_ingredients = list(original_ingredients)
            if attempt > 0:
                # A different ingredient order gives the sampler a different prompt.
                random.shuffle(current_ingredients)
                print(f"Retry {attempt}: Shuffling ingredients order")

            ingredients_string = ", ".join(current_ingredients)
            print(f"Attempt {attempt + 1}: {prefix + ingredients_string}")

            inputs = t5_tokenizer(
                prefix + ingredients_string,
                max_length=256,
                padding="max_length",
                truncation=True,
                return_tensors="jax"
            )

            output_ids = t5_model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                **generation_kwargs
            )

            # Decode with special tokens kept; target_postprocessing strips
            # them and rewrites the section/separator markers.
            generated_text = target_postprocessing(
                t5_tokenizer.batch_decode(output_ids.sequences, skip_special_tokens=False),
                special_tokens
            )[0]

            recipe = _parse_generated_recipe(generated_text)

            # Fill in any section the model did not produce.
            if "title" not in recipe:
                recipe["title"] = f"Recipe with {', '.join(current_ingredients[:3])}"
            if "ingredients" not in recipe:
                recipe["ingredients"] = current_ingredients
            if "directions" not in recipe:
                recipe["directions"] = ["No directions generated"]

            if validate_recipe_ingredients(recipe["ingredients"], original_ingredients):
                print(f"Success on attempt {attempt + 1}: Recipe has correct number of ingredients")
                return recipe

            print(
                f"Attempt {attempt + 1} failed: Expected {len(original_ingredients)} ingredients, got {len(recipe['ingredients'])}")
            if is_last_attempt:
                print("Max retries reached, returning last generated recipe")
                return recipe

        except Exception as e:
            # Best-effort boundary: a failed attempt must not abort the
            # remaining retries.
            print(f"Error in recipe generation attempt {attempt + 1}: {str(e)}")
            if is_last_attempt:
                return _fallback_recipe(original_ingredients)

    # Reached only when no attempt ran at all (max_retries <= 0).
    return _fallback_recipe(original_ingredients)
322
+
323
+
324
@app.route('/generate_recipe', methods=['POST'])
def handle_recipe_request():
    """
    Processes a recipe generation request with a given list of ingredients.
    Uses the intelligent ingredient combination feature.

    Expected JSON object fields (all optional, but at least one ingredient
    must be supplied):
        required_ingredients (list[str]): Ingredients that must be used.
        available_ingredients (list[str]): Ingredients to choose from.
        ingredients (list[str]): Legacy alias for ``required_ingredients``,
            used only when ``required_ingredients`` is empty.
        max_ingredients (int): Upper bound on selected ingredients (default 7).
        max_retries (int): Generation attempts (default 5).

    Returns:
        A JSON response with ``title``, ``ingredients``, ``directions`` and
        ``used_ingredients`` on success. Returns 415 for non-JSON requests,
        400 for malformed or invalid input, and 500 when generation fails.
    """
    if not request.is_json:
        return jsonify({"error": "Request must be JSON"}), 415

    # silent=True yields None for unparseable JSON so the response below is
    # this API's own JSON error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    # `or []` also covers fields sent explicitly as null.
    required_ingredients = data.get('required_ingredients') or []
    available_ingredients = data.get('available_ingredients') or []

    # For backward compatibility: If only 'ingredients' is specified, treat as required ingredients
    if data.get('ingredients') and not required_ingredients:
        required_ingredients = data['ingredients']

    # The body is untrusted: reject wrong types here with a 400 so they do
    # not surface as a 500 deep inside the models.
    cleaned_lists = []
    for raw_list in (required_ingredients, available_ingredients):
        if not isinstance(raw_list, list) or not all(isinstance(item, str) for item in raw_list):
            return jsonify({"error": "Ingredients must be lists of strings"}), 400
        # Drop blank entries and surrounding whitespace.
        cleaned_lists.append([item.strip() for item in raw_list if item.strip()])
    required_ingredients, available_ingredients = cleaned_lists

    max_ingredients = data.get('max_ingredients', 7)
    max_retries = data.get('max_retries', 5)
    for field_name, value in (('max_ingredients', max_ingredients), ('max_retries', max_retries)):
        # bool is a subclass of int, so exclude it explicitly.
        if isinstance(value, bool) or not isinstance(value, int) or value < 1:
            return jsonify({"error": f"'{field_name}' must be a positive integer"}), 400

    if not required_ingredients and not available_ingredients:
        return jsonify({"error": "No ingredients provided"}), 400

    try:
        # Always find best ingredient combination with RecipeBERT
        optimized_ingredients = find_best_ingredients(
            required_ingredients,
            available_ingredients,
            max_ingredients
        )

        # Generate recipe with optimized ingredients using T5 model with validation
        recipe = generate_recipe_with_t5(optimized_ingredients, max_retries)

        # Structured format for the client application
        return jsonify({
            'title': recipe['title'],
            'ingredients': recipe['ingredients'],
            'directions': recipe['directions'],
            'used_ingredients': optimized_ingredients
        })

    except Exception as e:
        # Top-level boundary: record the traceback server-side, then report.
        app.logger.exception("Recipe generation failed")
        return jsonify({"error": f"Error in recipe generation: {str(e)}"}), 500
374
+
375
+
376
@app.route('/generate_recipe_smart', methods=['POST'])
def handle_smart_recipe_request():
    """
    Legacy alias for ``/generate_recipe``.

    Kept for backward compatibility; the request is handled entirely by
    ``handle_recipe_request`` and its response is returned unchanged.
    """
    response = handle_recipe_request()
    return response
384
+
385
+
386
if __name__ == '__main__':
    import os

    # Debug mode enables the Werkzeug interactive debugger, which lets anyone
    # who can reach the port execute arbitrary code. The server binds to all
    # interfaces, so debug is off unless explicitly requested through the
    # FLASK_DEBUG environment variable (for local development only).
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(host='0.0.0.0', port=8000, debug=debug_enabled)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ Flask
2
+ Flask-Cors
3
+ transformers
4
+ torch
5
+ numpy
6
+ jax # Required: app.py always loads the T5 model through FlaxAutoModelForSeq2SeqLM (NOTE: transformers' Flax classes also need the `flax` package, which is not listed here)
7
+ jaxlib # Required by jax