Recipe / app.py
TimInf's picture
Update app.py
74d625f verified
raw
history blame
17.8 kB
import gradio as gr
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
import torch
import numpy as np
import random
import json
import os
# Disable Gradio analytics for better performance.
# The flag is passed through the process environment as the string "False".
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
# Load RecipeBERT model (for semantic ingredient combination).
# Tokenizer and encoder are loaded once at import time and shared by
# get_embedding(); eval() puts the encoder into evaluation (inference) mode.
bert_model_name = "alexdseo/RecipeBERT"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = AutoModel.from_pretrained(bert_model_name)
bert_model.eval()
# Load T5 recipe generation model (Flax weights) and its fast tokenizer;
# both are used by generate_recipe_with_t5().
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
# Token mapping for T5 model output processing:
# - special_tokens: the tokenizer's special tokens, removed from decoded text.
# - tokens_map: structural markers rewritten after decoding; "<sep>" becomes
#   the "--" item delimiter and "<section>" becomes a newline, which is what
#   the section parser in generate_recipe_with_t5() splits on.
special_tokens = t5_tokenizer.all_special_tokens
tokens_map = {
"<sep>": "--",
"<section>": "\n"
}
def get_embedding(text):
    """Return the mean-pooled RecipeBERT embedding of ``text``.

    The text is tokenized (with truncation and padding), run through the
    encoder without gradient tracking, and the vectors of the last hidden
    state are averaged using the attention mask as weights.
    """
    encoded = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        model_output = bert_model(**encoded)

    hidden_states = model_output.last_hidden_state
    # Broadcast the attention mask over the hidden dimension so that masked
    # positions contribute nothing to the sum.
    mask = encoded['attention_mask'].unsqueeze(-1).expand(hidden_states.size()).float()
    weighted_sum = torch.sum(hidden_states * mask, 1)
    # Clamp the per-sequence token count so the division can never hit zero.
    token_count = torch.clamp(mask.sum(1), min=1e-9)
    pooled = weighted_sum / token_count
    return pooled.squeeze(0)
def average_embedding(embedding_list):
    """Return the element-wise mean of the embeddings in ``embedding_list``.

    ``embedding_list`` holds ``(name, embedding)`` pairs; the names are
    ignored and the embeddings are stacked and averaged along the first axis.
    """
    vectors = [vector for _name, vector in embedding_list]
    stacked = torch.stack(vectors)
    return torch.mean(stacked, dim=0)
def get_cosine_similarity(vec1, vec2):
    """Compute the cosine similarity between two vectors.

    Args:
        vec1: A torch tensor, NumPy array, or any array-like of numbers
            (for example a plain list).
        vec2: Same kinds of input as ``vec1``.

    Both inputs are flattened first, so a ``(1, d)`` embedding compares
    correctly against a ``(d,)`` one.

    Returns:
        The cosine of the angle between the two vectors, or ``0.0`` when
        either vector has zero length (the similarity is undefined there).

    Raises:
        ValueError: If the flattened vectors differ in length (from ``np.dot``).
    """
    def _as_flat_array(vec):
        # Tensors are detached from autograd and copied to host memory first;
        # calling .numpy() directly raises for tensors that are not on the CPU.
        if torch.is_tensor(vec):
            vec = vec.detach().cpu().numpy()
        # asarray is a no-op for ndarrays and converts lists/tuples.
        return np.asarray(vec).flatten()

    flat1 = _as_flat_array(vec1)
    flat2 = _as_flat_array(vec2)
    dot_product = np.dot(flat1, flat2)
    norm_a = np.linalg.norm(flat1)
    norm_b = np.linalg.norm(flat2)
    # Avoid division by zero
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot_product / (norm_a * norm_b)
def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
    """Rank candidate ingredients by a blended similarity score.

    For every ``(name, embedding)`` candidate in ``embedding_list`` the score
    is ``avg_weight`` times its cosine similarity to ``query_vector`` plus
    ``1 - avg_weight`` times its mean cosine similarity to the embeddings in
    ``all_good_embeddings`` (also ``(name, embedding)`` pairs).

    Returns a list of ``(name, embedding, score)`` tuples, best score first.
    """
    individual_weight = 1 - avg_weight
    scored = []
    for candidate_name, candidate_emb in embedding_list:
        # How well the candidate matches the combination as a whole.
        to_query = get_cosine_similarity(query_vector, candidate_emb)
        # How well it matches each already-selected ingredient, on average.
        per_ingredient = [
            get_cosine_similarity(chosen_emb, candidate_emb)
            for _, chosen_emb in all_good_embeddings
        ]
        to_ingredients = sum(per_ingredient) / len(per_ingredient)
        score = avg_weight * to_query + individual_weight * to_ingredients
        scored.append((candidate_name, candidate_emb, score))
    # Highest combined score first.
    return sorted(scored, key=lambda entry: entry[2], reverse=True)
def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
    """
    Finds the best ingredients based on RecipeBERT embeddings.

    Starting from the required ingredients, greedily adds the available
    ingredient with the highest combined score (see ``get_combined_scores``)
    until ``max_ingredients`` is reached or no candidates remain.

    Args:
        required_ingredients: Names that must be part of the result. If empty,
            one available ingredient is picked at random as the starting point.
        available_ingredients: Extra names the selection may draw from.
        max_ingredients: Upper bound on the size of the result. Floats such as
            ``6.0`` are accepted and truncated to an integer.
        avg_weight: Weight of the similarity to the running average embedding
            within the combined score.

    Returns:
        List of ingredient names: the required ones in the order given
        (truncated to ``max_ingredients``), followed by the selected extras in
        the order they were chosen.
    """
    # UI sliders and JSON clients may deliver the limit as a float, which
    # cannot be used for slicing or range().
    max_ingredients = int(max_ingredients)

    # Clean the names and drop duplicates. dict.fromkeys keeps first-seen
    # order, so truncation and the returned order are reproducible (a set
    # would reorder the names arbitrarily between runs).
    required_ingredients = list(dict.fromkeys(
        ing.strip() for ing in required_ingredients if ing.strip()
    ))
    already_required = set(required_ingredients)
    available_ingredients = [
        ing
        for ing in dict.fromkeys(i.strip() for i in available_ingredients if i.strip())
        if ing not in already_required
    ]

    # Special case: If no required ingredients, randomly select one from available ingredients
    if not required_ingredients and available_ingredients:
        random_ingredient = random.choice(available_ingredients)
        required_ingredients = [random_ingredient]
        available_ingredients = [i for i in available_ingredients if i != random_ingredient]

    # If still no ingredients or already at max capacity
    if not required_ingredients or len(required_ingredients) >= max_ingredients:
        return required_ingredients[:max_ingredients]

    # If no additional ingredients available
    if not available_ingredients:
        return required_ingredients

    # Calculate embeddings for all ingredients
    embed_required = [(e, get_embedding(e)) for e in required_ingredients]
    embed_available = [(e, get_embedding(e)) for e in available_ingredients]

    # Number of ingredients to add
    num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients))

    # Copy required ingredients to final list
    final_ingredients = embed_required.copy()

    # Add best ingredients, one per round
    for _ in range(num_to_add):
        # Calculate average vector of current combination
        avg = average_embedding(final_ingredients)
        # Calculate combined scores for all remaining candidates
        candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight)
        # If no candidates left, stop early
        if not candidates:
            break
        # Results are sorted best-first, so the head is the winner
        best_name, best_embedding, _ = candidates[0]
        final_ingredients.append((best_name, best_embedding))
        # Remove the chosen ingredient from the candidate pool
        embed_available = [item for item in embed_available if item[0] != best_name]

    # Extract only ingredient names
    return [name for name, _ in final_ingredients]
def skip_special_tokens(text, special_tokens):
    """Return ``text`` with every occurrence of each special token deleted."""
    cleaned = text
    for marker in special_tokens:
        if marker in cleaned:
            cleaned = cleaned.replace(marker, "")
    return cleaned
def target_postprocessing(texts, special_tokens):
    """Clean decoded model output.

    Accepts a single text or a list of texts. For each one, the special
    tokens are removed and every marker in ``tokens_map`` is replaced by its
    mapped text. Always returns a list.
    """
    batch = texts if isinstance(texts, list) else [texts]
    processed = []
    for raw in batch:
        cleaned = skip_special_tokens(raw, special_tokens)
        for marker, replacement in tokens_map.items():
            cleaned = cleaned.replace(marker, replacement)
        processed.append(cleaned)
    return processed
def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
    """Check that the recipe lists roughly as many ingredients as expected.

    Only the counts are compared: blank or empty recipe entries are ignored,
    and the result is True when the counts differ by at most ``tolerance``.
    """
    non_blank = sum(1 for entry in recipe_ingredients if entry and entry.strip())
    difference = non_blank - len(expected_ingredients)
    return -tolerance <= difference <= tolerance
def _strip_label(section, label):
    """Return ``section`` without its leading ``label``, trimmed of whitespace."""
    # Slice instead of str.replace so a label repeated later in the text survives.
    return section[len(label):].strip()


def _split_items(text):
    """Split a "--"-delimited section into capitalized, non-empty entries."""
    return [item.strip().capitalize() for item in text.split("--") if item.strip()]


def _parse_recipe_sections(generated_text):
    """Parse post-processed model output into a dict of the sections it contains."""
    recipe = {}
    for section in generated_text.split("\n"):
        section = section.strip()
        if section.startswith("title:"):
            recipe["title"] = _strip_label(section, "title:").capitalize()
        elif section.startswith("ingredients:"):
            recipe["ingredients"] = _split_items(_strip_label(section, "ingredients:"))
        elif section.startswith("directions:"):
            recipe["directions"] = _split_items(_strip_label(section, "directions:"))
    return recipe


def _fallback_recipe(ingredients):
    """Build the placeholder recipe returned when no attempt succeeded."""
    return {
        "title": f"Recipe with {ingredients[0] if ingredients else 'ingredients'}",
        "ingredients": ingredients,
        "directions": ["Error generating recipe instructions"]
    }


def generate_recipe_with_t5(ingredients_list, max_retries=5):
    """Generates a recipe using the T5 recipe generation model with validation.

    The ingredients are joined into an ``"items: a, b, c"`` prompt and the
    model samples a recipe from it. A recipe is accepted when its ingredient
    count is within one of the number of requested ingredients; otherwise the
    ingredients are shuffled and generation is retried.

    Args:
        ingredients_list: Ingredient names to build the recipe from.
        max_retries: Maximum number of generation attempts.

    Returns:
        A dict with the keys ``"title"`` (str), ``"ingredients"`` (list of
        str) and ``"directions"`` (list of str). The last attempt's recipe is
        returned even if it fails validation. If the last attempt raised, or
        no attempt was made, a placeholder recipe listing the requested
        ingredients is returned instead.
    """
    import logging

    logger = logging.getLogger(__name__)
    original_ingredients = list(ingredients_list)
    prefix = "items: "
    # Generation settings (identical for every attempt)
    generation_kwargs = {
        "max_length": 512,
        "min_length": 64,
        "do_sample": True,
        "top_k": 60,
        "top_p": 0.95
    }

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        # The first attempt keeps the caller's order; retries shuffle a copy
        # so the model is prompted differently.
        current_ingredients = list(original_ingredients)
        if attempt > 0:
            random.shuffle(current_ingredients)
        try:
            # Tokenize the prompt, padded/truncated to a fixed length
            inputs = t5_tokenizer(
                prefix + ", ".join(current_ingredients),
                max_length=256,
                padding="max_length",
                truncation=True,
                return_tensors="jax"
            )
            # Generate text
            output_ids = t5_model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                **generation_kwargs
            )
            # Decode with special tokens kept so the post-processing step can
            # rewrite the structural markers itself
            generated_text = target_postprocessing(
                t5_tokenizer.batch_decode(output_ids.sequences, skip_special_tokens=False),
                special_tokens
            )[0]

            recipe = _parse_recipe_sections(generated_text)
            # Fill in any section the model did not produce
            recipe.setdefault("title", f"Recipe with {', '.join(current_ingredients[:3])}")
            recipe.setdefault("ingredients", current_ingredients)
            recipe.setdefault("directions", ["No directions generated"])

            # Accept a valid recipe, or whatever we have once retries run out
            if is_last_attempt or validate_recipe_ingredients(
                    recipe["ingredients"], original_ingredients, tolerance=1):
                return recipe
        except Exception:
            # Best effort: record the failure and move on to the next attempt
            # instead of silently discarding the error.
            logger.warning(
                "Recipe generation attempt %d/%d failed",
                attempt + 1, max_retries, exc_info=True
            )

    # Fallback: the final attempt raised, or no attempt was made
    return _fallback_recipe(original_ingredients)
def generate_recipe_interface(required_ingredients_text, available_ingredients_text, max_ingredients):
    """Main interface function for Gradio.

    Args:
        required_ingredients_text: Comma-separated ingredients that must be used.
        available_ingredients_text: Comma-separated optional ingredients.
        max_ingredients: Maximum number of ingredients for the recipe.

    Returns:
        A 4-tuple of Markdown strings: ``(title, ingredients, directions,
        used_ingredients)``. On invalid input or any error the first element
        carries the error message and the other three are empty strings.
    """
    try:
        # Parse ingredient inputs
        required_ingredients = []
        available_ingredients = []
        if required_ingredients_text:
            required_ingredients = [ing.strip() for ing in required_ingredients_text.split(',') if ing.strip()]
        if available_ingredients_text:
            available_ingredients = [ing.strip() for ing in available_ingredients_text.split(',') if ing.strip()]

        # Validate inputs
        if not required_ingredients and not available_ingredients:
            return "❌ **Error:** Please provide at least some ingredients!", "", "", ""

        # Find best ingredient combination
        optimized_ingredients = find_best_ingredients(
            required_ingredients,
            available_ingredients,
            max_ingredients
        )

        # Generate recipe
        recipe = generate_recipe_with_t5(optimized_ingredients)

        # Format output (emoji and bullet written as real characters; they
        # were previously stored as mis-decoded byte sequences)
        title = f"🍽️ **{recipe['title']}**"
        ingredients_formatted = "## 📋 Ingredients:\n" + "\n".join(
            f"• {ing}" for ing in recipe['ingredients'])
        directions_formatted = "## 👨‍🍳 Instructions:\n" + "\n".join(
            f"{i}. {step}" for i, step in enumerate(recipe['directions'], start=1))
        used_ingredients = "## ✅ Used Ingredients:\n" + ", ".join(optimized_ingredients)
        return title, ingredients_formatted, directions_formatted, used_ingredients
    except Exception as e:
        # Surface the failure in the UI instead of crashing the handler
        return f"❌ **Error:** {str(e)}", "", "", ""
def generate_recipe_api(required_ingredients_text, available_ingredients_text, max_ingredients):
    """Run the recipe pipeline and return the result as a JSON string.

    On success the JSON object has the keys ``title``, ``ingredients``,
    ``directions`` and ``used_ingredients``; when no ingredients are given or
    anything fails, it has a single ``error`` key instead.
    """
    def split_csv(raw_text):
        # Empty or missing input yields no ingredients.
        if not raw_text:
            return []
        return [piece.strip() for piece in raw_text.split(',') if piece.strip()]

    try:
        required_ingredients = split_csv(required_ingredients_text)
        available_ingredients = split_csv(available_ingredients_text)

        # Nothing to cook with
        if not (required_ingredients or available_ingredients):
            return json.dumps({"error": "No ingredients provided"}, indent=2)

        # Select the ingredient combination, then generate the recipe from it
        optimized_ingredients = find_best_ingredients(
            required_ingredients, available_ingredients, max_ingredients
        )
        recipe = generate_recipe_with_t5(optimized_ingredients)

        payload = {
            'title': recipe['title'],
            'ingredients': recipe['ingredients'],
            'directions': recipe['directions'],
            'used_ingredients': optimized_ingredients
        }
        return json.dumps(payload, indent=2, ensure_ascii=False)
    except Exception as e:
        return json.dumps({"error": f"Error in recipe generation: {str(e)}"}, indent=2)
# Create Gradio interface: one tab with the Markdown-rendered recipe and one
# tab that shows the same result as raw JSON.
# NOTE(review): the nesting of the layout blocks below was reconstructed from
# the order of the calls — confirm it matches the intended layout.
with gr.Blocks(title="🍳 AI Recipe Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🍳 AI Recipe Generator
Generate delicious recipes using AI! This tool uses **RecipeBERT** to find the best ingredient combinations and **T5** to generate complete recipes.
## How to use:
1. **Required Ingredients:** Enter ingredients you must use (comma-separated)
2. **Available Ingredients:** Enter additional ingredients you have available (comma-separated)
3. **Max Ingredients:** Set the maximum number of ingredients for your recipe
4. Click **Generate Recipe** to create your personalized recipe!
""")
    with gr.Tab("🍽️ Recipe Generator"):
        with gr.Row():
            # Left column: inputs
            with gr.Column():
                required_ingredients = gr.Textbox(
                    label="🎯 Required Ingredients",
                    placeholder="chicken, rice, onions",
                    info="Ingredients that must be included in the recipe (comma-separated)"
                )
                available_ingredients = gr.Textbox(
                    label="🥕 Available Ingredients",
                    placeholder="garlic, tomatoes, basil, cheese",
                    info="Additional ingredients you have available (comma-separated)"
                )
                max_ingredients = gr.Slider(
                    minimum=3, maximum=12, value=7, step=1,
                    label="📊 Maximum Ingredients",
                    info="Maximum number of ingredients to use in the recipe"
                )
                generate_btn = gr.Button("🚀 Generate Recipe", variant="primary", size="lg")
            # Right column: recipe title and the ingredients actually used
            with gr.Column():
                recipe_title = gr.Markdown()
                used_ingredients = gr.Markdown()
        with gr.Row():
            with gr.Column():
                recipe_ingredients = gr.Markdown()
            with gr.Column():
                recipe_directions = gr.Markdown()
    with gr.Tab("🔌 API Format"):
        gr.Markdown("""
## API Response Format
This tab shows the response in JSON format, compatible with your Flutter app.
""")
        with gr.Row():
            with gr.Column():
                api_required = gr.Textbox(
                    label="Required Ingredients",
                    placeholder="chicken, rice, onions"
                )
                api_available = gr.Textbox(
                    label="Available Ingredients",
                    placeholder="garlic, tomatoes, basil"
                )
                api_max = gr.Slider(
                    minimum=3, maximum=12, value=7, step=1,
                    label="Max Ingredients"
                )
                api_generate_btn = gr.Button("Generate JSON", variant="secondary")
            with gr.Column():
                api_output = gr.Code(language="json", label="API Response")
    # Event handlers - IMPORTANT: these create the API endpoints!
    generate_btn.click(
        fn=generate_recipe_interface,
        inputs=[required_ingredients, available_ingredients, max_ingredients],
        outputs=[recipe_title, recipe_ingredients, recipe_directions, used_ingredients]
    )
    api_generate_btn.click(
        fn=generate_recipe_api,
        inputs=[api_required, api_available, api_max],
        outputs=[api_output]
    )
    # Example inputs (the last one has no required ingredients, so the
    # starting ingredient is drawn from the available ones)
    gr.Examples(
        examples=[
            ["chicken, rice", "onions, garlic, tomatoes, basil", 6],
            ["eggs, flour", "milk, sugar, vanilla, butter", 7],
            ["salmon", "lemon, dill, potatoes, asparagus", 5],
            ["", "beef, potatoes, carrots, onions, garlic", 6]
        ],
        inputs=[required_ingredients, available_ingredients, max_ingredients]
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)