|
|
import gradio as gr |
|
|
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel |
|
|
import torch |
|
|
import numpy as np |
|
|
import random |
|
|
import json |
|
|
|
|
|
|
|
|
# --- RecipeBERT: encoder whose hidden states get_embedding() mean-pools ---
bert_model_name = "alexdseo/RecipeBERT"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
# nn.Module.eval() returns the module itself, so the call can be chained.
bert_model = AutoModel.from_pretrained(bert_model_name).eval()

# --- Flax T5: sequence-to-sequence model used by generate_recipe_with_t5() ---
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)

# Tokens stripped from decoded text by target_postprocessing().
special_tokens = t5_tokenizer.all_special_tokens

# Structural markers in the generated text and the plain-text separators that
# replace them: "--" between items, a newline between sections.
tokens_map = {"<sep>": "--", "<section>": "\n"}
|
|
|
|
|
def get_embedding(text):
    """Embed ``text`` by mean-pooling the BERT model's last hidden states.

    Positions masked out by the tokenizer's attention mask (padding) do not
    contribute to the average.

    Args:
        text: Input handed to ``bert_tokenizer`` with truncation and padding
            enabled.

    Returns:
        The pooled embedding tensor; the leading dimension is squeezed away
        when it has size 1.
    """
    encoded = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        model_output = bert_model(**encoded)

    hidden_states = model_output.last_hidden_state

    # Broadcast the mask across the hidden dimension so masked tokens are
    # zeroed out before summing.
    mask = encoded['attention_mask'].unsqueeze(-1).expand(hidden_states.size()).float()
    masked_total = (hidden_states * mask).sum(dim=1)

    # Clamp the denominator so an all-zero mask cannot divide by zero.
    token_counts = mask.sum(dim=1).clamp(min=1e-9)

    pooled = masked_total / token_counts
    return pooled.squeeze(0)
|
|
|
|
|
def average_embedding(embedding_list):
    """Return the element-wise mean of the embeddings in ``embedding_list``.

    Args:
        embedding_list: Sequence of ``(name, embedding)`` pairs; the names are
            ignored and the embeddings must be stackable tensors.

    Returns:
        A tensor holding the mean over the stacked embeddings (dimension 0).
    """
    vectors = []
    for _name, vector in embedding_list:
        vectors.append(vector)
    return torch.mean(torch.stack(vectors), dim=0)
|
|
|
|
|
def get_cosine_similarity(vec1, vec2):
    """Return the cosine similarity between two vectors as a Python float.

    Args:
        vec1: A torch tensor, NumPy array, or plain numeric sequence. It is
            flattened before the comparison.
        vec2: Same kinds of input as ``vec1``; must have the same number of
            elements once flattened.

    Returns:
        ``dot(vec1, vec2) / (|vec1| * |vec2|)`` as a ``float``, or ``0.0``
        when either vector has zero norm.
    """
    def _as_flat_array(vec):
        if torch.is_tensor(vec):
            # .cpu() is a no-op for CPU tensors and makes tensors living on
            # another device convertible to NumPy.
            vec = vec.detach().cpu().numpy()
        # asarray lets lists/tuples through; arrays are passed along as-is.
        return np.asarray(vec).flatten()

    flat_a = _as_flat_array(vec1)
    flat_b = _as_flat_array(vec2)

    norm_a = np.linalg.norm(flat_a)
    norm_b = np.linalg.norm(flat_b)

    # The similarity is undefined for a zero vector; report "no similarity".
    if norm_a == 0 or norm_b == 0:
        return 0.0

    # float() keeps the return type uniform instead of mixing int and
    # NumPy scalars.
    return float(np.dot(flat_a, flat_b) / (norm_a * norm_b))
|
|
|
|
|
def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
    """Rank candidates by a weighted blend of two cosine similarities.

    For every candidate the score is::

        avg_weight * sim(query_vector, candidate)
        + (1 - avg_weight) * mean(sim(reference, candidate))

    where the mean runs over the embeddings in ``all_good_embeddings``.

    Args:
        query_vector: Vector each candidate is compared against (callers in
            this file pass the average of the selected embeddings).
        embedding_list: Candidate ``(name, embedding)`` pairs.
        all_good_embeddings: Reference ``(name, embedding)`` pairs compared
            individually against each candidate.
        avg_weight: Weight of the query-vector similarity; the remainder
            goes to the mean individual similarity.

    Returns:
        A list of ``(name, embedding, score)`` tuples sorted by descending
        score. Empty when ``embedding_list`` is empty.
    """
    reference_vectors = [vec for _, vec in all_good_embeddings]
    scored = []

    for name, emb in embedding_list:
        query_similarity = get_cosine_similarity(query_vector, emb)

        if reference_vectors:
            mean_reference_similarity = sum(
                get_cosine_similarity(ref, emb) for ref in reference_vectors
            ) / len(reference_vectors)
            score = (
                avg_weight * query_similarity
                + (1 - avg_weight) * mean_reference_similarity
            )
        else:
            # With no reference embeddings the mean is undefined (it used to
            # raise ZeroDivisionError); rank on the query similarity alone.
            score = query_similarity

        scored.append((name, emb, score))

    scored.sort(key=lambda item: item[2], reverse=True)
    return scored
|
|
|
|
|
def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
    """Pick up to ``max_ingredients`` ingredients using embedding similarity.

    All required ingredients are kept (truncated to ``max_ingredients`` if
    there are too many). Remaining slots are filled greedily from
    ``available_ingredients``: each round, the candidate with the highest
    combined score against the current selection is added.

    Args:
        required_ingredients: Ingredients that must be used. Duplicates are
            dropped, keeping the first occurrence and the original order.
        available_ingredients: Optional extras to choose from. Duplicates
            and anything already required are dropped.
        max_ingredients: Upper bound on the size of the returned list.
        avg_weight: Forwarded to ``get_combined_scores``.

    Returns:
        A list of ingredient names: the required ones first, in input
        order, followed by the chosen extras in the order they were picked.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order. set() would
    # scramble the order from run to run, which made both the truncation
    # below and the prompt order effectively random.
    required_ingredients = list(dict.fromkeys(required_ingredients))
    required_lookup = set(required_ingredients)
    available_ingredients = [
        item for item in dict.fromkeys(available_ingredients)
        if item not in required_lookup
    ]

    # Nothing required: seed the selection with a random available ingredient.
    if not required_ingredients and available_ingredients:
        random_ingredient = random.choice(available_ingredients)
        required_ingredients = [random_ingredient]
        available_ingredients = [i for i in available_ingredients if i != random_ingredient]

    # Nothing to work with, or the required list already fills every slot.
    if not required_ingredients or len(required_ingredients) >= max_ingredients:
        return required_ingredients[:max_ingredients]

    # No extras to choose from.
    if not available_ingredients:
        return required_ingredients

    embed_required = [(e, get_embedding(e)) for e in required_ingredients]
    embed_available = [(e, get_embedding(e)) for e in available_ingredients]

    num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients))

    final_ingredients = list(embed_required)

    for _ in range(num_to_add):
        # Re-average each round so the newest pick influences the next one.
        avg = average_embedding(final_ingredients)
        candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight)

        if not candidates:
            break

        best_name, best_embedding, _ = candidates[0]
        final_ingredients.append((best_name, best_embedding))
        embed_available = [item for item in embed_available if item[0] != best_name]

    return [name for name, _ in final_ingredients]
|
|
|
|
|
def skip_special_tokens(text, special_tokens):
    """Return ``text`` with every occurrence of each special token removed.

    Args:
        text: The string to clean.
        special_tokens: Iterable of token strings to delete, applied in order.

    Returns:
        The cleaned string.
    """
    cleaned = text
    for marker in special_tokens:
        cleaned = cleaned.replace(marker, "")
    return cleaned
|
|
|
|
|
def target_postprocessing(texts, special_tokens):
    """Clean decoded model output.

    Each text first has the given special tokens removed, then every key of
    the module-level ``tokens_map`` is replaced by its mapped value.

    Args:
        texts: A single value or a list of decoded strings; a non-list
            value is wrapped in a one-element list.
        special_tokens: Token strings to strip from every text.

    Returns:
        A list with one cleaned string per input text.
    """
    batch = texts if isinstance(texts, list) else [texts]

    processed = []
    for raw in batch:
        cleaned = skip_special_tokens(raw, special_tokens)
        for marker, replacement in tokens_map.items():
            cleaned = cleaned.replace(marker, replacement)
        processed.append(cleaned)

    return processed
|
|
|
|
|
def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
    """Check that a recipe lists about as many ingredients as expected.

    Only the counts are compared, not the ingredient names. Empty, blank,
    or ``None`` entries in ``recipe_ingredients`` are not counted.

    Args:
        recipe_ingredients: Ingredient strings taken from the recipe.
        expected_ingredients: The ingredients the recipe was asked to use.
        tolerance: Largest allowed absolute difference between the two
            counts. ``0`` (the default) demands an exact match.

    Returns:
        ``True`` when the counts differ by at most ``tolerance``.
    """
    recipe_count = sum(1 for ing in recipe_ingredients if ing and ing.strip())
    expected_count = len(expected_ingredients)
    # "<=" rather than "==": with "==" a non-zero tolerance accepted only a
    # difference of exactly that size and rejected an exact match.
    return abs(recipe_count - expected_count) <= tolerance
|
|
|
|
|
def _fallback_recipe(ingredients):
    """Build the placeholder recipe returned when generation did not succeed."""
    lead = ingredients[0] if ingredients else 'ingredients'
    return {
        "title": f"Recipe with {lead}",
        "ingredients": ingredients,
        "directions": ["Error generating recipe instructions"],
    }


def _split_items(section, label):
    """Drop the leading ``label`` from ``section`` and split the rest on "--"."""
    # Slice off the prefix only: str.replace would also delete the label
    # wherever else it happens to occur in the section text.
    body = section[len(label):]
    return [item.strip().capitalize() for item in body.split("--") if item.strip()]


def _parse_generated_recipe(generated_text):
    """Parse post-processed model output into title/ingredients/directions."""
    recipe = {}
    for section in generated_text.split("\n"):
        section = section.strip()
        if section.startswith("title:"):
            recipe["title"] = section[len("title:"):].strip().capitalize()
        elif section.startswith("ingredients:"):
            recipe["ingredients"] = _split_items(section, "ingredients:")
        elif section.startswith("directions:"):
            recipe["directions"] = _split_items(section, "directions:")
    return recipe


def generate_recipe_with_t5(ingredients_list, max_retries=5):
    """Generate a recipe with the T5 model, retrying until it validates.

    The first attempt uses the ingredients in the given order; each retry
    shuffles them. An attempt is accepted when
    ``validate_recipe_ingredients`` approves the generated ingredient list.
    The last attempt's recipe is returned even if it does not validate.

    Args:
        ingredients_list: Sequence of ingredient names to cook with. It is
            copied, never modified.
        max_retries: Maximum number of generation attempts.

    Returns:
        A dict with keys ``"title"`` (str), ``"ingredients"`` (list of str)
        and ``"directions"`` (list of str). When no attempt completes
        without an exception (or ``max_retries`` is not positive), a
        placeholder recipe with an error message as its only direction is
        returned.
    """
    import logging

    original_ingredients = list(ingredients_list)
    prefix = "items: "

    # Loop-invariant sampling configuration.
    # NOTE(review): no prng_key is passed to generate(). Flax generation
    # appears to fall back to a fixed default key in that case, so the
    # shuffle on retries may be the only source of variation between
    # attempts; confirm against the transformers documentation.
    generation_kwargs = {
        "max_length": 512,
        "min_length": 64,
        "do_sample": True,
        "top_k": 60,
        "top_p": 0.95,
    }

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            current_ingredients = original_ingredients.copy()
            if attempt > 0:
                # Present the ingredients in a new order on each retry.
                random.shuffle(current_ingredients)

            inputs = t5_tokenizer(
                prefix + ", ".join(current_ingredients),
                max_length=256,
                padding="max_length",
                truncation=True,
                return_tensors="jax",
            )
            output_ids = t5_model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                **generation_kwargs
            )

            # Decode with special tokens intact; target_postprocessing strips
            # them and rewrites the section/item markers.
            decoded = t5_tokenizer.batch_decode(output_ids.sequences, skip_special_tokens=False)
            generated_text = target_postprocessing(decoded, special_tokens)[0]

            recipe = _parse_generated_recipe(generated_text)

            # Fill in whatever sections the model failed to produce.
            recipe.setdefault("title", f"Recipe with {', '.join(current_ingredients[:3])}")
            recipe.setdefault("ingredients", current_ingredients)
            recipe.setdefault("directions", ["No directions generated"])

            if validate_recipe_ingredients(recipe["ingredients"], original_ingredients) or is_last_attempt:
                return recipe
        except Exception:
            # Deliberate best-effort: a failed attempt must not abort the
            # request, but it is logged instead of vanishing silently.
            logging.getLogger(__name__).warning(
                "Recipe generation attempt %d/%d failed",
                attempt + 1,
                max_retries,
                exc_info=True,
            )

    # Reached when the final attempt raised or no attempt was made.
    return _fallback_recipe(original_ingredients)
|
|
|
|
|
def _normalize_ingredients(value):
    """Return the non-blank, stripped ingredient names found in ``value``."""
    if value is None:
        return []
    if isinstance(value, str):
        # Accept the same comma-separated form the web UI takes.
        value = value.split(',')
    return [item.strip() for item in value if isinstance(item, str) and item.strip()]


def flutter_api_generate_recipe(ingredients_data):
    """JSON-in / JSON-out entry point for API clients.

    Recognised keys of the input object:

    * ``required_ingredients``: ingredients that must be used
    * ``available_ingredients``: optional extras to choose from
    * ``ingredients``: used as the required list when
      ``required_ingredients`` is missing or empty
    * ``max_ingredients``: defaults to 7
    * ``max_retries``: defaults to 5

    Ingredient values may be lists of strings or a single comma-separated
    string; entries are stripped, and blank or non-string entries are
    ignored.

    Args:
        ingredients_data: A JSON string or an already-parsed dict.

    Returns:
        A JSON string. On success it holds ``title``, ``ingredients``,
        ``directions`` and ``used_ingredients``; on failure it holds a
        single ``error`` key. This function does not raise.
    """
    try:
        if isinstance(ingredients_data, str):
            data = json.loads(ingredients_data)
        else:
            data = ingredients_data

        if not isinstance(data, dict):
            return json.dumps({"error": "Input must be a JSON object"})

        required_ingredients = _normalize_ingredients(data.get('required_ingredients'))
        available_ingredients = _normalize_ingredients(data.get('available_ingredients'))

        # Alternative key: a plain "ingredients" list counts as required.
        if not required_ingredients:
            required_ingredients = _normalize_ingredients(data.get('ingredients'))

        # Checked after normalisation so blank entries such as [""] no
        # longer slip past this guard.
        if not required_ingredients and not available_ingredients:
            return json.dumps({"error": "No ingredients provided"})

        # Clients may send numbers as floats (e.g. 6.0); slicing and range()
        # further down need real ints.
        max_ingredients = int(data.get('max_ingredients', 7))
        max_retries = int(data.get('max_retries', 5))

        optimized_ingredients = find_best_ingredients(
            required_ingredients,
            available_ingredients,
            max_ingredients
        )

        recipe = generate_recipe_with_t5(optimized_ingredients, max_retries)

        return json.dumps({
            'title': recipe['title'],
            'ingredients': recipe['ingredients'],
            'directions': recipe['directions'],
            'used_ingredients': optimized_ingredients,
        })

    except Exception as e:
        # API boundary: report every failure as JSON instead of raising.
        return json.dumps({"error": f"Error in recipe generation: {str(e)}"})
|
|
|
|
|
def gradio_ui_generate_recipe(required_ingredients_text, available_ingredients_text, max_ingredients, max_retries):
    """Handle a "Generate Recipe" request from the web form.

    Splits the two comma-separated text inputs, sends them through
    ``flutter_api_generate_recipe`` and formats the result for display.

    Args:
        required_ingredients_text: Comma-separated required ingredients.
        available_ingredients_text: Comma-separated optional ingredients.
        max_ingredients: Value of the "Maximum Ingredients" control.
        max_retries: Value of the "Max Retries" control.

    Returns:
        A 4-tuple ``(title, ingredients, directions, used_ingredients)`` of
        display strings. On failure the first element carries the error
        message and the other three are empty strings.
    """
    def _split_names(raw_text):
        # Strip each comma-separated piece and discard the empty ones.
        return [piece.strip() for piece in raw_text.split(',') if piece.strip()]

    try:
        payload = {
            'required_ingredients': _split_names(required_ingredients_text),
            'available_ingredients': _split_names(available_ingredients_text),
            'max_ingredients': max_ingredients,
            'max_retries': max_retries,
        }

        result = json.loads(flutter_api_generate_recipe(payload))

        if 'error' in result:
            return result['error'], "", "", ""

        bullet_lines = [f"• {ing}" for ing in result['ingredients']]
        numbered_steps = [
            f"{number}. {step}"
            for number, step in enumerate(result['directions'], start=1)
        ]
        used_ingredients = ', '.join(result['used_ingredients'])

        return (
            result['title'],
            '\n'.join(bullet_lines),
            '\n'.join(numbered_steps),
            used_ingredients,
        )

    except Exception as e:
        return f"Error: {str(e)}", "", "", ""
|
|
|
|
|
|
|
|
# Gradio front end: one tab with a form-based UI, one tab for exercising the
# JSON API function directly.
# NOTE(review): the nesting of the layout contexts below was reconstructed
# from statement order; confirm the rendered layout matches what is intended.
with gr.Blocks(title="AI Recipe Generator") as demo:
    gr.Markdown("# 🍳 AI Recipe Generator")
    gr.Markdown("Generate recipes using AI with intelligent ingredient combination!")

    with gr.Tab("Web Interface"):
        with gr.Row():
            # Left column: inputs.
            with gr.Column():
                required_ing = gr.Textbox(
                    label="Required Ingredients (comma-separated)",
                    placeholder="chicken, rice, onion",
                    lines=2,
                )
                available_ing = gr.Textbox(
                    label="Available Ingredients (comma-separated)",
                    placeholder="garlic, tomato, pepper, herbs",
                    lines=2,
                )
                max_ing = gr.Slider(3, 10, value=7, step=1, label="Maximum Ingredients")
                max_retries = gr.Slider(1, 10, value=5, step=1, label="Max Retries")
                generate_btn = gr.Button("Generate Recipe", variant="primary")

            # Right column: read-only results.
            with gr.Column():
                title_output = gr.Textbox(label="Recipe Title", interactive=False)
                ingredients_output = gr.Textbox(label="Ingredients", lines=8, interactive=False)
                directions_output = gr.Textbox(label="Directions", lines=10, interactive=False)
                used_ingredients_output = gr.Textbox(label="Used Ingredients", interactive=False)

        # Output order matches the 4-tuple returned by gradio_ui_generate_recipe.
        generate_btn.click(
            fn=gradio_ui_generate_recipe,
            inputs=[required_ing, available_ing, max_ing, max_retries],
            outputs=[title_output, ingredients_output, directions_output, used_ingredients_output],
        )

    with gr.Tab("API Testing"):
        gr.Markdown("### Test the Flutter API")
        gr.Markdown("This tab uses the same function that Flutter apps will call via API")

        api_input = gr.Textbox(
            label="JSON Input (Flutter API Format)",
            placeholder='{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic"], "max_ingredients": 6}',
            lines=4,
        )
        api_output = gr.Textbox(label="JSON Output", lines=15, interactive=False)
        api_test_btn = gr.Button("Test API", variant="secondary")

        # The raw JSON text goes straight to the API function.
        api_test_btn.click(
            fn=flutter_api_generate_recipe,
            inputs=[api_input],
            outputs=[api_output],
        )

        # Sample payloads; the second one uses the alternative "ingredients" key.
        gr.Examples(
            examples=[
                ['{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic", "tomato"], "max_ingredients": 6}'],
                ['{"ingredients": ["pasta"], "available_ingredients": ["cheese", "mushrooms", "cream"], "max_ingredients": 5}'],
            ],
            inputs=[api_input],
        )


# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()