import json
import logging
import random

import gradio as gr
import numpy as np
import torch
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
|
|
|
|
|
# Embedding model: the RecipeBERT checkpoint, loaded once at import time and
# used by get_embedding() to score ingredient similarity.
bert_model_name = "alexdseo/RecipeBERT"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = AutoModel.from_pretrained(bert_model_name)
# The model is only used for inference, so switch it to evaluation mode.
bert_model.eval()

# Generation model: the Flax T5 checkpoint that writes the recipe text.
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)

# Tokens removed from the decoded output during post-processing.
special_tokens = t5_tokenizer.all_special_tokens
# Structural markers in the decoded text and the plain-text separators that
# replace them: "--" between list items, a newline between recipe sections.
tokens_map = {
    "<sep>": "--",
    "<section>": "\n"
}
|
|
|
|
|
def get_embedding(text):
    """Embed ``text`` with the BERT model using attention-masked mean pooling.

    Args:
        text: The string (or batch of strings) to embed.

    Returns:
        The mean of the last-layer token vectors, weighted by the attention
        mask so padding positions do not contribute. A leading dimension of
        size one is squeezed away.
    """
    encoded = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        hidden_states = bert_model(**encoded).last_hidden_state

    # Broadcast the mask over the hidden dimension so it can weight each token.
    mask = encoded["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float()
    masked_total = torch.sum(hidden_states * mask, 1)
    # Clamp guards against division by zero for an all-padding row.
    token_count = torch.clamp(mask.sum(1), min=1e-9)

    pooled = masked_total / token_count
    return pooled.squeeze(0)
|
|
|
|
|
def average_embedding(embedding_list):
    """Return the element-wise mean of the embeddings in ``embedding_list``.

    Args:
        embedding_list: Sequence of ``(name, embedding_tensor)`` pairs; the
            names are ignored.

    Returns:
        A tensor with the same shape as a single embedding.
    """
    vectors = [vector for _name, vector in embedding_list]
    stacked = torch.stack(vectors)
    return stacked.mean(dim=0)
|
|
|
|
|
def _as_flat_array(vec):
    """Convert a tensor, array or plain sequence to a 1-D NumPy array."""
    if torch.is_tensor(vec):
        # .numpy() only works on CPU tensors, so move the data first.
        vec = vec.detach().cpu().numpy()
    return np.asarray(vec).ravel()


def get_cosine_similarity(vec1, vec2):
    """Compute the cosine similarity between two vectors.

    Args:
        vec1: A torch tensor, NumPy array or plain sequence of numbers.
        vec2: Same kinds of input as ``vec1``, with the same number of
            elements once flattened.

    Returns:
        The cosine similarity as a Python ``float``. Returns ``0.0`` when
        either vector has zero length, where the similarity is undefined.
    """
    flat1 = _as_flat_array(vec1)
    flat2 = _as_flat_array(vec2)

    norm_a = np.linalg.norm(flat1)
    norm_b = np.linalg.norm(flat2)
    if norm_a == 0 or norm_b == 0:
        return 0.0

    return float(np.dot(flat1, flat2) / (norm_a * norm_b))
|
|
|
|
|
def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
    """Rank candidates by a blend of two cosine similarities.

    For each candidate the score is
    ``avg_weight * sim(query_vector, candidate)
    + (1 - avg_weight) * mean(sim(good, candidate) for each good embedding)``.

    Args:
        query_vector: Reference vector each candidate is compared against.
        embedding_list: Candidates as ``(name, embedding)`` pairs.
        all_good_embeddings: Already-selected items as ``(name, embedding)``
            pairs. May be empty; the score then falls back to the similarity
            with ``query_vector`` alone instead of dividing by zero.
        avg_weight: Weight given to the similarity with ``query_vector``.

    Returns:
        A list of ``(name, embedding, combined_score)`` tuples sorted by
        descending score.
    """
    results = []

    for name, emb in embedding_list:
        avg_similarity = get_cosine_similarity(query_vector, emb)

        individual_similarities = [
            get_cosine_similarity(good_emb, emb) for _, good_emb in all_good_embeddings
        ]

        if individual_similarities:
            avg_individual_similarity = sum(individual_similarities) / len(individual_similarities)
            combined_score = avg_weight * avg_similarity + (1 - avg_weight) * avg_individual_similarity
        else:
            # No reference items to average over: use the only signal available.
            combined_score = avg_similarity

        results.append((name, emb, combined_score))

    results.sort(key=lambda item: item[2], reverse=True)
    return results
|
|
|
|
|
def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
    """Pick the ingredient set to cook with, greedily, from embedding scores.

    All required ingredients are kept (truncated to ``max_ingredients`` if
    there are too many). Remaining slots are filled one at a time with the
    available ingredient that scores highest against the current selection.

    Args:
        required_ingredients: Ingredient names that must be used.
        available_ingredients: Optional extra ingredient names.
        max_ingredients: Upper bound on the size of the result. Coerced to
            ``int`` because the value may arrive as a float (e.g. from a UI
            slider), which slicing and ``range()`` reject.
        avg_weight: Forwarded to ``get_combined_scores``.

    Returns:
        A list of ingredient names: required ones first, in input order,
        followed by the selected extras in the order they were chosen.
    """
    max_ingredients = int(max_ingredients)

    # dict.fromkeys removes duplicates while keeping input order. A set would
    # make the order (and which items survive truncation) vary between runs.
    required_ingredients = list(dict.fromkeys(
        ing.strip() for ing in required_ingredients if ing.strip()
    ))
    required_lookup = set(required_ingredients)
    available_ingredients = list(dict.fromkeys(
        name
        for name in (ing.strip() for ing in available_ingredients)
        if name and name not in required_lookup
    ))

    # Nothing is mandatory: promote one available ingredient at random to seed
    # the selection.
    if not required_ingredients and available_ingredients:
        random_ingredient = random.choice(available_ingredients)
        required_ingredients = [random_ingredient]
        available_ingredients = [i for i in available_ingredients if i != random_ingredient]

    if not required_ingredients or len(required_ingredients) >= max_ingredients:
        return required_ingredients[:max_ingredients]

    if not available_ingredients:
        return required_ingredients

    embed_required = [(e, get_embedding(e)) for e in required_ingredients]
    embed_available = [(e, get_embedding(e)) for e in available_ingredients]

    num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients))

    final_ingredients = embed_required.copy()

    for _ in range(num_to_add):
        # Re-score against the current selection so each pick accounts for
        # the ingredients added in earlier rounds.
        avg = average_embedding(final_ingredients)
        candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight)

        if not candidates:
            break

        best_name, best_embedding, _score = candidates[0]
        final_ingredients.append((best_name, best_embedding))
        embed_available = [item for item in embed_available if item[0] != best_name]

    return [name for name, _ in final_ingredients]
|
|
|
|
|
def skip_special_tokens(text, special_tokens):
    """Return ``text`` with every occurrence of each special token deleted.

    Args:
        text: The string to clean.
        special_tokens: Iterable of token strings to strip out.

    Returns:
        The cleaned string.
    """
    cleaned = text
    for marker in special_tokens:
        cleaned = cleaned.replace(marker, "")
    return cleaned
|
|
|
|
|
def target_postprocessing(texts, special_tokens):
    """Clean decoded model output.

    Special tokens are removed first, then each marker in ``tokens_map`` is
    replaced by its plain-text separator.

    Args:
        texts: A list of decoded strings, or a single value which is wrapped
            in a one-element list.
        special_tokens: Token strings to delete from each text.

    Returns:
        A list with one cleaned string per input text.
    """
    batch = texts if isinstance(texts, list) else [texts]

    def _clean(raw):
        cleaned = skip_special_tokens(raw, special_tokens)
        for marker, replacement in tokens_map.items():
            cleaned = cleaned.replace(marker, replacement)
        return cleaned

    return [_clean(raw) for raw in batch]
|
|
|
|
|
def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
    """Check that the recipe lists roughly as many ingredients as expected.

    Only the counts are compared; the names are not matched.

    Args:
        recipe_ingredients: Ingredient entries from the generated recipe;
            empty and whitespace-only entries are not counted.
        expected_ingredients: The ingredients that were requested.
        tolerance: Largest allowed absolute difference between the counts.

    Returns:
        ``True`` when the counts differ by at most ``tolerance``.
    """
    non_blank = sum(1 for entry in recipe_ingredients if entry and entry.strip())
    difference = non_blank - len(expected_ingredients)
    return -tolerance <= difference <= tolerance
|
|
|
|
|
def _fallback_recipe(ingredients):
    """Build the placeholder recipe returned when no attempt produced text."""
    first = ingredients[0] if ingredients else "ingredients"
    return {
        "title": f"Recipe with {first}",
        "ingredients": list(ingredients),
        "directions": ["Error generating recipe instructions"],
    }


def _split_recipe_items(text):
    """Split a ``--``-separated section body into capitalised, non-empty items."""
    return [item.strip().capitalize() for item in text.split("--") if item.strip()]


def _parse_recipe_text(generated_text):
    """Parse post-processed model output into a dict of the sections found."""
    recipe = {}
    for section in generated_text.split("\n"):
        section = section.strip()
        # Slice off only the leading label; str.replace would also delete the
        # same word wherever it appears later in the section.
        if section.startswith("title:"):
            recipe["title"] = section[len("title:"):].strip().capitalize()
        elif section.startswith("ingredients:"):
            recipe["ingredients"] = _split_recipe_items(section[len("ingredients:"):])
        elif section.startswith("directions:"):
            recipe["directions"] = _split_recipe_items(section[len("directions:"):])
    return recipe


def generate_recipe_with_t5(ingredients_list, max_retries=5):
    """Generate a recipe with the T5 model, retrying until it looks valid.

    Each attempt samples a new generation (from the second attempt on, with
    the ingredients shuffled). A recipe is accepted as soon as its ingredient
    count is within one of the requested count.

    Args:
        ingredients_list: Ingredient names to build the prompt from. The list
            is not modified and is never returned by reference.
        max_retries: Maximum number of generation attempts.

    Returns:
        A dict with ``"title"`` (str), ``"ingredients"`` (list of str) and
        ``"directions"`` (list of str). If no attempt passes validation, the
        most recent recipe that was generated is returned; if every attempt
        failed, a placeholder recipe with an error message as its direction.
    """
    original_ingredients = list(ingredients_list)
    prefix = "items: "

    # Sampling is enabled, so repeated attempts can produce different text.
    generation_kwargs = {
        "max_length": 512,
        "min_length": 64,
        "do_sample": True,
        "top_k": 60,
        "top_p": 0.95,
    }

    last_recipe = None

    for attempt in range(max_retries):
        current_ingredients = list(original_ingredients)
        if attempt > 0:
            random.shuffle(current_ingredients)

        try:
            inputs = t5_tokenizer(
                prefix + ", ".join(current_ingredients),
                max_length=256,
                padding="max_length",
                truncation=True,
                return_tensors="jax",
            )
            output_ids = t5_model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                **generation_kwargs
            )
            generated_text = target_postprocessing(
                t5_tokenizer.batch_decode(output_ids.sequences, skip_special_tokens=False),
                special_tokens,
            )[0]
        except Exception:
            # Boundary around model inference: record the failure instead of
            # hiding it, then move on to the next attempt.
            logging.getLogger(__name__).exception(
                "Recipe generation attempt %d/%d failed", attempt + 1, max_retries
            )
            continue

        recipe = _parse_recipe_text(generated_text)

        # Fill in any section the model did not produce.
        if "title" not in recipe:
            recipe["title"] = f"Recipe with {', '.join(current_ingredients[:3])}"
        if "ingredients" not in recipe:
            recipe["ingredients"] = current_ingredients
        if "directions" not in recipe:
            recipe["directions"] = ["No directions generated"]

        last_recipe = recipe
        if validate_recipe_ingredients(recipe["ingredients"], original_ingredients, tolerance=1):
            return recipe

    # Prefer an imperfect recipe over the placeholder when one was generated.
    if last_recipe is not None:
        return last_recipe
    return _fallback_recipe(original_ingredients)
|
|
|
|
|
def generate_recipe_interface(required_ingredients_text, available_ingredients_text, max_ingredients):
    """Gradio handler: build a recipe and format it as Markdown.

    Args:
        required_ingredients_text: Comma-separated ingredients that must be
            used; may be empty or ``None``.
        available_ingredients_text: Comma-separated optional ingredients; may
            be empty or ``None``.
        max_ingredients: Maximum number of ingredients to select.

    Returns:
        A 4-tuple of Markdown strings: title, ingredient list, numbered
        directions, and the ingredients that were selected. On failure the
        first element carries the error message and the rest are empty.
    """
    # NOTE(review): the emoji in the strings below were restored from
    # mis-encoded text. The error and "Ingredients" icons are best guesses --
    # confirm against the intended design.
    try:
        required_ingredients = []
        available_ingredients = []

        if required_ingredients_text:
            required_ingredients = [ing.strip() for ing in required_ingredients_text.split(',') if ing.strip()]

        if available_ingredients_text:
            available_ingredients = [ing.strip() for ing in available_ingredients_text.split(',') if ing.strip()]

        if not required_ingredients and not available_ingredients:
            return "❌ **Error:** Please provide at least some ingredients!", "", "", ""

        optimized_ingredients = find_best_ingredients(
            required_ingredients,
            available_ingredients,
            max_ingredients
        )

        recipe = generate_recipe_with_t5(optimized_ingredients)

        title = f"🍽️ **{recipe['title']}**"

        ingredients_formatted = "## 📋 Ingredients:\n" + "\n".join(
            [f"• {ing}" for ing in recipe['ingredients']])

        directions_formatted = "## 👨‍🍳 Instructions:\n" + "\n".join(
            [f"{i + 1}. {step}" for i, step in enumerate(recipe['directions'])])

        used_ingredients = "## ✅ Used Ingredients:\n" + ", ".join(optimized_ingredients)

        return title, ingredients_formatted, directions_formatted, used_ingredients

    except Exception as e:
        # UI boundary: show the failure in the title slot rather than crash.
        return f"❌ **Error:** {str(e)}", "", "", ""
|
|
|
|
|
def generate_recipe_api(required_ingredients_text, available_ingredients_text, max_ingredients):
    """Build a recipe and return it serialised as a JSON string.

    Args:
        required_ingredients_text: Comma-separated ingredients that must be
            used; may be empty or ``None``.
        available_ingredients_text: Comma-separated optional ingredients; may
            be empty or ``None``.
        max_ingredients: Maximum number of ingredients to select.

    Returns:
        A JSON document with ``title``, ``ingredients``, ``directions`` and
        ``used_ingredients`` on success, or a single ``error`` key when no
        ingredients were given or generation failed.
    """
    def _parse_csv(raw):
        # Empty or missing text means "no ingredients".
        if not raw:
            return []
        return [piece.strip() for piece in raw.split(',') if piece.strip()]

    try:
        must_use = _parse_csv(required_ingredients_text)
        may_use = _parse_csv(available_ingredients_text)

        if not (must_use or may_use):
            return json.dumps({"error": "No ingredients provided"}, indent=2)

        chosen = find_best_ingredients(must_use, may_use, max_ingredients)
        recipe = generate_recipe_with_t5(chosen)

        payload = {key: recipe[key] for key in ('title', 'ingredients', 'directions')}
        payload['used_ingredients'] = chosen

        return json.dumps(payload, indent=2, ensure_ascii=False)

    except Exception as exc:
        return json.dumps({"error": f"Error in recipe generation: {str(exc)}"}, indent=2)
|
|
|
|
|
|
|
# Gradio UI: one tab renders the recipe as Markdown, the other shows the same
# result as JSON.
# NOTE(review): the emoji in labels and headings were restored from
# mis-encoded text. The icons on "Available Ingredients", "Maximum
# Ingredients", "Generate Recipe" and "API Format" are best guesses, and the
# nesting of rows/columns was reconstructed from a paste that had lost its
# indentation -- confirm both against the intended design.
with gr.Blocks(title="🍳 AI Recipe Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🍳 AI Recipe Generator

    Generate delicious recipes using AI! This tool uses **RecipeBERT** to find the best ingredient combinations and **T5** to generate complete recipes.

    ## How to use:
    1. **Required Ingredients:** Enter ingredients you must use (comma-separated)
    2. **Available Ingredients:** Enter additional ingredients you have available (comma-separated)
    3. **Max Ingredients:** Set the maximum number of ingredients for your recipe
    4. Click **Generate Recipe** to create your personalized recipe!
    """)

    with gr.Tab("🍽️ Recipe Generator"):
        with gr.Row():
            with gr.Column():
                required_ingredients = gr.Textbox(
                    label="🎯 Required Ingredients",
                    placeholder="chicken, rice, onions",
                    info="Ingredients that must be included in the recipe (comma-separated)"
                )
                available_ingredients = gr.Textbox(
                    label="🥕 Available Ingredients",
                    placeholder="garlic, tomatoes, basil, cheese",
                    info="Additional ingredients you have available (comma-separated)"
                )
                max_ingredients = gr.Slider(
                    minimum=3, maximum=12, value=7, step=1,
                    label="📊 Maximum Ingredients",
                    info="Maximum number of ingredients to use in the recipe"
                )
                generate_btn = gr.Button("🚀 Generate Recipe", variant="primary", size="lg")

            with gr.Column():
                recipe_title = gr.Markdown()
                used_ingredients = gr.Markdown()

        with gr.Row():
            with gr.Column():
                recipe_ingredients = gr.Markdown()
            with gr.Column():
                recipe_directions = gr.Markdown()

    with gr.Tab("🔗 API Format"):
        gr.Markdown("""
        ## API Response Format
        This tab shows the response in JSON format, compatible with your Flutter app.
        """)

        with gr.Row():
            with gr.Column():
                api_required = gr.Textbox(
                    label="Required Ingredients",
                    placeholder="chicken, rice, onions"
                )
                api_available = gr.Textbox(
                    label="Available Ingredients",
                    placeholder="garlic, tomatoes, basil"
                )
                api_max = gr.Slider(
                    minimum=3, maximum=12, value=7, step=1,
                    label="Max Ingredients"
                )
                api_generate_btn = gr.Button("Generate JSON", variant="secondary")

            with gr.Column():
                api_output = gr.Code(language="json", label="API Response")

    # The output order must match the tuple returned by
    # generate_recipe_interface: title, ingredients, directions, used.
    generate_btn.click(
        fn=generate_recipe_interface,
        inputs=[required_ingredients, available_ingredients, max_ingredients],
        outputs=[recipe_title, recipe_ingredients, recipe_directions, used_ingredients]
    )

    api_generate_btn.click(
        fn=generate_recipe_api,
        inputs=[api_required, api_available, api_max],
        outputs=[api_output]
    )

    # Clickable presets; the last one leaves the required field empty.
    gr.Examples(
        examples=[
            ["chicken, rice", "onions, garlic, tomatoes, basil", 6],
            ["eggs, flour", "milk, sugar, vanilla, butter", 7],
            ["salmon", "lemon, dill, potatoes, asparagus", 5],
            ["", "beef, potatoes, carrots, onions, garlic", 6]
        ],
        inputs=[required_ingredients, available_ingredients, max_ingredients]
    )


if __name__ == "__main__":
    demo.launch()