|
|
import json
import logging
import os
import random

import gradio as gr
import jax
import numpy as np
import torch
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
|
|
|
|
|
|
|
|
# Opt out of Gradio usage analytics for this app.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

# RecipeBERT encoder: produces the ingredient embeddings used for selection.
bert_model_name = "alexdseo/RecipeBERT"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = AutoModel.from_pretrained(bert_model_name)
bert_model.eval()  # inference only

# T5 recipe generator, loaded as a Flax seq2seq model.
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)

# Tokens stripped from the decoded output, and the structural markers that are
# rewritten into the separators the recipe parser splits on.
special_tokens = t5_tokenizer.all_special_tokens
tokens_map = {
    "<sep>": "--",
    "<section>": "\n",
}
|
|
|
|
|
|
|
|
def get_embedding(text):
    """Return the mean-pooled RecipeBERT embedding of ``text``.

    The text is tokenized with ``bert_tokenizer`` and run through
    ``bert_model`` without gradient tracking. The last hidden state is then
    averaged over the token axis, counting only positions whose attention
    mask is set, so padding does not dilute the mean.

    Args:
        text: Text to embed, passed straight to the tokenizer.

    Returns:
        A torch tensor with the leading batch axis squeezed away when it has
        size 1 (i.e. when a single text was passed).
    """
    inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        outputs = bert_model(**inputs)

    token_embeddings = outputs.last_hidden_state
    # Broadcast the mask across the hidden dimension so it can weight tokens.
    mask = inputs["attention_mask"].unsqueeze(-1).expand(token_embeddings.size()).float()

    summed = torch.sum(token_embeddings * mask, 1)
    # Clamp so an all-zero mask cannot cause a division by zero.
    token_counts = torch.clamp(mask.sum(1), min=1e-9)

    return (summed / token_counts).squeeze(0)
|
|
|
|
|
|
|
|
def average_embedding(embedding_list):
    """Return the element-wise mean of the embeddings in ``embedding_list``.

    Args:
        embedding_list: Sequence of ``(name, embedding)`` pairs whose
            embeddings are equally-shaped torch tensors. The names are ignored.

    Returns:
        A tensor with the same shape as a single embedding.

    Raises:
        ValueError: If ``embedding_list`` is empty (a mean is undefined).
    """
    if not embedding_list:
        raise ValueError("average_embedding() requires at least one embedding")

    stacked = torch.stack([emb for _, emb in embedding_list])
    return stacked.mean(dim=0)
|
|
|
|
|
|
|
|
def get_cosine_similarity(vec1, vec2):
    """Return the cosine similarity between two vectors as a float.

    Args:
        vec1: Torch tensor, NumPy array, or array-like. Flattened before use.
        vec2: Same kinds as ``vec1``; must have the same number of elements.

    Returns:
        The cosine similarity, or ``0.0`` when either vector has zero norm.
    """
    # .cpu() makes the NumPy conversion work for tensors on an accelerator;
    # it is a no-op for tensors already on the CPU.
    if torch.is_tensor(vec1):
        vec1 = vec1.detach().cpu().numpy()
    if torch.is_tensor(vec2):
        vec2 = vec2.detach().cpu().numpy()

    vec1 = np.asarray(vec1).ravel()
    vec2 = np.asarray(vec2).ravel()

    norm_a = np.linalg.norm(vec1)
    norm_b = np.linalg.norm(vec2)
    if norm_a == 0 or norm_b == 0:
        return 0.0

    return float(np.dot(vec1, vec2) / (norm_a * norm_b))
|
|
|
|
|
|
|
|
def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
    """Rank candidate embeddings by a blend of two cosine similarities.

    For each candidate the score is
    ``avg_weight * sim(query_vector, candidate)
    + (1 - avg_weight) * mean(sim(good, candidate) for each good embedding)``.

    Args:
        query_vector: Reference vector (callers pass the mean of the already
            selected embeddings).
        embedding_list: Candidate ``(name, embedding)`` pairs to score.
        all_good_embeddings: ``(name, embedding)`` pairs already selected.
        avg_weight: Weight of the similarity to ``query_vector``.

    Returns:
        A list of ``(name, embedding, score)`` tuples sorted by score, highest
        first. Ties keep their order from ``embedding_list``.
    """
    results = []

    for name, emb in embedding_list:
        avg_similarity = get_cosine_similarity(query_vector, emb)

        if all_good_embeddings:
            individual_similarities = [
                get_cosine_similarity(good_emb, emb)
                for _, good_emb in all_good_embeddings
            ]
            avg_individual_similarity = sum(individual_similarities) / len(individual_similarities)
        else:
            # Nothing selected yet: fall back to the query similarity rather
            # than dividing by zero.
            avg_individual_similarity = avg_similarity

        combined_score = avg_weight * avg_similarity + (1 - avg_weight) * avg_individual_similarity
        results.append((name, emb, combined_score))

    results.sort(key=lambda item: item[2], reverse=True)
    return results
|
|
|
|
|
|
|
|
def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
    """Choose up to ``max_ingredients`` ingredients using RecipeBERT embeddings.

    Required ingredients are always kept. Available ingredients are then added
    greedily: at each step the candidate with the best combined score against
    the current selection (see ``get_combined_scores``) is appended.

    Args:
        required_ingredients: Ingredient names that must be used.
        available_ingredients: Optional extra ingredient names to pick from.
        max_ingredients: Upper bound on the size of the result. Coerced to
            ``int`` so a float slider value such as ``6.0`` is accepted.
        avg_weight: Forwarded to ``get_combined_scores``.

    Returns:
        A list of ingredient names: the required ones in their input order,
        followed by the selected extras in selection order. If nothing is
        required, one available ingredient is picked at random as the seed.
    """
    max_ingredients = max(int(max_ingredients), 0)

    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # result (and the truncation below) is deterministic, unlike set().
    required = list(dict.fromkeys(
        ing.strip() for ing in required_ingredients if ing.strip()
    ))
    required_lookup = set(required)
    stripped_available = (ing.strip() for ing in available_ingredients)
    available = list(dict.fromkeys(
        ing for ing in stripped_available if ing and ing not in required_lookup
    ))

    # With no required ingredient, seed the selection with a random one.
    if not required and available:
        seed = random.choice(available)
        required = [seed]
        available = [ing for ing in available if ing != seed]

    if not required or len(required) >= max_ingredients:
        return required[:max_ingredients]

    if not available:
        return required

    selected = [(name, get_embedding(name)) for name in required]
    candidates_pool = [(name, get_embedding(name)) for name in available]

    num_to_add = min(max_ingredients - len(required), len(available))

    for _ in range(num_to_add):
        centroid = average_embedding(selected)
        ranked = get_combined_scores(centroid, candidates_pool, selected, avg_weight)
        if not ranked:
            break

        best_name, best_embedding, _ = ranked[0]
        selected.append((best_name, best_embedding))
        candidates_pool = [item for item in candidates_pool if item[0] != best_name]

    return [name for name, _ in selected]
|
|
|
|
|
|
|
|
def skip_special_tokens(text, special_tokens):
    """Return ``text`` with every occurrence of each special token removed.

    Args:
        text: The string to clean.
        special_tokens: Iterable of token strings to delete from ``text``.

    Returns:
        The cleaned string. Surrounding whitespace is left untouched.
    """
    for token in special_tokens:
        text = text.replace(token, "")
    return text
|
|
|
|
|
|
|
|
def target_postprocessing(texts, special_tokens):
    """Clean decoded model output and rewrite its structural markers.

    Each text first has ``special_tokens`` removed, then every key of the
    module-level ``tokens_map`` is replaced by its mapped value.

    Args:
        texts: A single string, or a list/tuple of strings.
        special_tokens: Token strings to strip from each text.

    Returns:
        A list with one processed string per input text.
    """
    # A lone string is treated as a batch of one; lists and tuples are batches.
    if not isinstance(texts, (list, tuple)):
        texts = [texts]

    new_texts = []
    for text in texts:
        text = skip_special_tokens(text, special_tokens)
        for marker, replacement in tokens_map.items():
            text = text.replace(marker, replacement)
        new_texts.append(text)

    return new_texts
|
|
|
|
|
|
|
|
def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
    """Check that a recipe lists roughly as many ingredients as expected.

    Only the counts are compared, not the ingredient names. Empty,
    whitespace-only and ``None`` entries in ``recipe_ingredients`` are ignored.

    Args:
        recipe_ingredients: Ingredient strings parsed from the generated recipe.
        expected_ingredients: Ingredients that were given to the generator.
        tolerance: Maximum allowed absolute difference between the two counts.

    Returns:
        ``True`` if the counts differ by at most ``tolerance``.
    """
    recipe_count = sum(1 for ing in recipe_ingredients if ing and ing.strip())
    expected_count = len(expected_ingredients)
    return abs(recipe_count - expected_count) <= tolerance
|
|
|
|
|
|
|
|
_LOGGER = logging.getLogger(__name__)


def _fallback_recipe(ingredients):
    """Build the placeholder recipe returned when generation fails."""
    first = ingredients[0] if ingredients else "ingredients"
    return {
        "title": f"Recipe with {first}",
        "ingredients": list(ingredients),
        "directions": ["Error generating recipe instructions"],
    }


def _split_recipe_items(text):
    """Split a ``--``-separated section body into capitalized, non-empty items."""
    return [item.strip().capitalize() for item in text.split("--") if item.strip()]


def _parse_recipe_text(generated_text):
    """Parse post-processed model output into a partial recipe dict.

    Each line is expected to start with ``title:``, ``ingredients:`` or
    ``directions:``; other lines are ignored. Only the leading label is
    removed, so the same word appearing later in the line is preserved.
    """
    recipe = {}
    for section in generated_text.split("\n"):
        section = section.strip()
        if section.startswith("title:"):
            recipe["title"] = section[len("title:"):].strip().capitalize()
        elif section.startswith("ingredients:"):
            recipe["ingredients"] = _split_recipe_items(section[len("ingredients:"):])
        elif section.startswith("directions:"):
            recipe["directions"] = _split_recipe_items(section[len("directions:"):])
    return recipe


def generate_recipe_with_t5(ingredients_list, max_retries=5):
    """Generate a recipe with the T5 model, retrying until it validates.

    The first attempt uses the ingredients in the given order; later attempts
    shuffle them. An attempt is accepted when the number of generated
    ingredients is within 1 of the number requested.

    Args:
        ingredients_list: Ingredient names to cook with. Not modified.
        max_retries: Maximum number of generation attempts.

    Returns:
        A dict with ``"title"`` (str), ``"ingredients"`` (list of str) and
        ``"directions"`` (list of str). If no attempt validates, the recipe
        from the last attempt is returned; if the last attempt raised (or
        ``max_retries`` is 0), a placeholder recipe is returned instead.
    """
    original_ingredients = list(ingredients_list)

    generation_kwargs = {
        "max_length": 512,
        "min_length": 64,
        "do_sample": True,
        "top_k": 60,
        "top_p": 0.95,
    }

    for attempt in range(max_retries):
        current_ingredients = list(original_ingredients)
        if attempt > 0:
            random.shuffle(current_ingredients)

        try:
            inputs = t5_tokenizer(
                "items: " + ", ".join(current_ingredients),
                max_length=256,
                padding="max_length",
                truncation=True,
                return_tensors="jax",
            )
            output_ids = t5_model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                # Flax generate() falls back to PRNGKey(0) when no key is
                # given, which makes sampling repeat for identical input.
                # 31 bits keeps the seed within int32 range.
                prng_key=jax.random.PRNGKey(random.getrandbits(31)),
                **generation_kwargs,
            )
            generated_text = target_postprocessing(
                t5_tokenizer.batch_decode(output_ids.sequences, skip_special_tokens=False),
                special_tokens,
            )[0]
        except Exception:
            # Best effort: record the failure and move on to the next attempt.
            _LOGGER.warning(
                "Recipe generation attempt %d/%d failed",
                attempt + 1, max_retries, exc_info=True,
            )
            continue

        recipe = _parse_recipe_text(generated_text)
        # Fill in any section the model did not produce.
        recipe.setdefault("title", f"Recipe with {', '.join(current_ingredients[:3])}")
        recipe.setdefault("ingredients", current_ingredients)
        recipe.setdefault("directions", ["No directions generated"])

        is_last_attempt = attempt == max_retries - 1
        if is_last_attempt or validate_recipe_ingredients(
            recipe["ingredients"], original_ingredients, tolerance=1
        ):
            return recipe

    return _fallback_recipe(original_ingredients)
|
|
|
|
|
|
|
|
def generate_recipe_interface(required_ingredients_text, available_ingredients_text, max_ingredients):
    """Gradio handler: turn comma-separated ingredient text into Markdown.

    Args:
        required_ingredients_text: Comma-separated ingredients that must be
            used. May be empty or ``None``.
        available_ingredients_text: Comma-separated optional ingredients. May
            be empty or ``None``.
        max_ingredients: Maximum number of ingredients; coerced to ``int``.

    Returns:
        A 4-tuple of Markdown strings
        ``(title, ingredients, directions, used_ingredients)``. On any error
        the first element carries the error message and the rest are empty.
    """
    # NOTE(review): the lone Greek glyphs left in the strings below look like
    # mis-decoded emoji whose original code point cannot be recovered from this
    # file; they are kept verbatim. Restore them from the original source.
    try:
        required_ingredients = [
            ing.strip() for ing in (required_ingredients_text or "").split(",") if ing.strip()
        ]
        available_ingredients = [
            ing.strip() for ing in (available_ingredients_text or "").split(",") if ing.strip()
        ]

        if not required_ingredients and not available_ingredients:
            return "β **Error:** Please provide at least some ingredients!", "", "", ""

        optimized_ingredients = find_best_ingredients(
            required_ingredients,
            available_ingredients,
            int(max_ingredients),  # sliders may deliver a float
        )

        recipe = generate_recipe_with_t5(optimized_ingredients)

        title = f"🍽️ **{recipe['title']}**"
        ingredients_formatted = "## π Ingredients:\n" + "\n".join(
            f"• {ing}" for ing in recipe["ingredients"]
        )
        directions_formatted = "## 👨‍🍳 Instructions:\n" + "\n".join(
            f"{i + 1}. {step}" for i, step in enumerate(recipe["directions"])
        )
        used_ingredients = "## ✅ Used Ingredients:\n" + ", ".join(optimized_ingredients)

        return title, ingredients_formatted, directions_formatted, used_ingredients

    except Exception as e:
        # UI boundary: show the failure in the title slot instead of raising.
        return f"β **Error:** {str(e)}", "", "", ""
|
|
|
|
|
|
|
|
def generate_recipe_api(required_ingredients_text, available_ingredients_text, max_ingredients):
    """Gradio handler that returns the generated recipe as a JSON string.

    Args:
        required_ingredients_text: Comma-separated ingredients that must be
            used. May be empty or ``None``.
        available_ingredients_text: Comma-separated optional ingredients. May
            be empty or ``None``.
        max_ingredients: Maximum number of ingredients; coerced to ``int``.

    Returns:
        A JSON document (2-space indent, non-ASCII kept as-is). On success it
        has the keys ``title``, ``ingredients``, ``directions`` and
        ``used_ingredients``; on failure it has a single ``error`` key.
    """
    try:
        required_ingredients = [
            ing.strip() for ing in (required_ingredients_text or "").split(",") if ing.strip()
        ]
        available_ingredients = [
            ing.strip() for ing in (available_ingredients_text or "").split(",") if ing.strip()
        ]

        if not required_ingredients and not available_ingredients:
            return json.dumps({"error": "No ingredients provided"}, indent=2, ensure_ascii=False)

        optimized_ingredients = find_best_ingredients(
            required_ingredients,
            available_ingredients,
            int(max_ingredients),  # sliders may deliver a float
        )

        recipe = generate_recipe_with_t5(optimized_ingredients)

        api_response = {
            "title": recipe["title"],
            "ingredients": recipe["ingredients"],
            "directions": recipe["directions"],
            "used_ingredients": optimized_ingredients,
        }
        return json.dumps(api_response, indent=2, ensure_ascii=False)

    except Exception as e:
        # API boundary: report the failure as JSON instead of raising.
        return json.dumps(
            {"error": f"Error in recipe generation: {str(e)}"},
            indent=2,
            ensure_ascii=False,
        )
|
|
|
|
|
|
|
|
|
|
|
# Markdown shown in the UI. Kept flush-left so the rendered text does not
# depend on the indentation of the layout code below.
_INTRO_MARKDOWN = """
# 🍳 AI Recipe Generator

Generate delicious recipes using AI! This tool uses **RecipeBERT** to find the best ingredient combinations and **T5** to generate complete recipes.

## How to use:
1. **Required Ingredients:** Enter ingredients you must use (comma-separated)
2. **Available Ingredients:** Enter additional ingredients you have available (comma-separated)
3. **Max Ingredients:** Set the maximum number of ingredients for your recipe
4. Click **Generate Recipe** to create your personalized recipe!
"""

_API_MARKDOWN = """
## API Response Format
This tab shows the response in JSON format, compatible with your Flutter app.
"""

# NOTE(review): the layout nesting below was reconstructed from the order of the
# `with` statements because the original indentation was lost; confirm against
# the intended UI. Lone Greek glyphs left in labels look like mis-decoded emoji
# that cannot be recovered from this file and are kept verbatim.
with gr.Blocks(title="🍳 AI Recipe Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Tab("🍽️ Recipe Generator"):
        with gr.Row():
            # Left column: inputs.
            with gr.Column():
                required_ingredients = gr.Textbox(
                    label="π― Required Ingredients",
                    placeholder="chicken, rice, onions",
                    info="Ingredients that must be included in the recipe (comma-separated)",
                )
                available_ingredients = gr.Textbox(
                    label="π₯ Available Ingredients",
                    placeholder="garlic, tomatoes, basil, cheese",
                    info="Additional ingredients you have available (comma-separated)",
                )
                max_ingredients = gr.Slider(
                    minimum=3, maximum=12, value=7, step=1,
                    label="π Maximum Ingredients",
                    info="Maximum number of ingredients to use in the recipe",
                )
                generate_btn = gr.Button("π Generate Recipe", variant="primary", size="lg")

            # Right column: recipe headline and the ingredients actually used.
            with gr.Column():
                recipe_title = gr.Markdown()
                used_ingredients = gr.Markdown()

        with gr.Row():
            with gr.Column():
                recipe_ingredients = gr.Markdown()
            with gr.Column():
                recipe_directions = gr.Markdown()

    with gr.Tab("π API Format"):
        gr.Markdown(_API_MARKDOWN)

        with gr.Row():
            with gr.Column():
                api_required = gr.Textbox(
                    label="Required Ingredients",
                    placeholder="chicken, rice, onions",
                )
                api_available = gr.Textbox(
                    label="Available Ingredients",
                    placeholder="garlic, tomatoes, basil",
                )
                api_max = gr.Slider(
                    minimum=3, maximum=12, value=7, step=1,
                    label="Max Ingredients",
                )
                api_generate_btn = gr.Button("Generate JSON", variant="secondary")

            with gr.Column():
                api_output = gr.Code(language="json", label="API Response")

    # Output order must match the tuple returned by generate_recipe_interface:
    # (title, ingredients, directions, used ingredients).
    generate_btn.click(
        fn=generate_recipe_interface,
        inputs=[required_ingredients, available_ingredients, max_ingredients],
        outputs=[recipe_title, recipe_ingredients, recipe_directions, used_ingredients],
    )

    api_generate_btn.click(
        fn=generate_recipe_api,
        inputs=[api_required, api_available, api_max],
        outputs=[api_output],
    )

    # Clickable presets that fill the inputs of the first tab.
    gr.Examples(
        examples=[
            ["chicken, rice", "onions, garlic, tomatoes, basil", 6],
            ["eggs, flour", "milk, sugar, vanilla, butter", 7],
            ["salmon", "lemon, dill, potatoes, asparagus", 5],
            ["", "beef, potatoes, carrots, onions, garlic", 6],
        ],
        inputs=[required_ingredients, available_ingredients, max_ingredients],
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Listen on all interfaces on port 7860 and surface handler errors in the UI.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )