"""Gradio app that picks a harmonious ingredient set and generates a recipe.

RecipeBERT embeddings are used to choose which of the available ingredients
fit best with the required ones; a Flax T5 model then generates the recipe
text, which is parsed into title / ingredients / directions.
"""

import json
import logging
import os
import random

import gradio as gr
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer, FlaxAutoModelForSeq2SeqLM

logger = logging.getLogger(__name__)

# Disable Gradio analytics for better performance
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

# Load RecipeBERT model (for semantic ingredient combination)
bert_model_name = "alexdseo/RecipeBERT"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = AutoModel.from_pretrained(bert_model_name)
bert_model.eval()

# Load T5 recipe generation model
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)

# Token mapping for T5 model output processing.
# The generated text marks list items with "<sep>" and section boundaries
# with "<section>"; they are rewritten to the "--" / newline delimiters that
# the section parser below splits on.
special_tokens = t5_tokenizer.all_special_tokens
tokens_map = {
    "<sep>": "--",
    "<section>": "\n",
}


def get_embedding(text):
    """Return the mean-pooled RecipeBERT embedding of ``text``.

    The token embeddings of the last hidden state are averaged, weighted by
    the attention mask so that padding positions do not contribute.
    """
    inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = bert_model(**inputs)

    token_embeddings = outputs.last_hidden_state
    mask = inputs["attention_mask"].unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, 1)
    # Clamp so an all-zero mask cannot cause a division by zero.
    counts = torch.clamp(mask.sum(1), min=1e-9)
    return (summed / counts).squeeze(0)


def average_embedding(embedding_list):
    """Return the element-wise mean of the embeddings in ``embedding_list``.

    ``embedding_list`` is a list of ``(name, embedding)`` pairs; only the
    embeddings are used.
    """
    tensors = torch.stack([emb for _, emb in embedding_list])
    return tensors.mean(dim=0)


def _to_flat_array(vec):
    """Convert a tensor or array-like into a flat NumPy array."""
    if torch.is_tensor(vec):
        # .cpu() is a no-op for CPU tensors and required for any other device.
        vec = vec.detach().cpu().numpy()
    return np.asarray(vec).flatten()


def get_cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two vectors (0.0 if either is all zeros).

    Inputs may be torch tensors or anything ``np.asarray`` accepts; both are
    flattened before comparison.
    """
    vec1 = _to_flat_array(vec1)
    vec2 = _to_flat_array(vec2)

    norm_a = np.linalg.norm(vec1)
    norm_b = np.linalg.norm(vec2)
    # Avoid division by zero
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return float(np.dot(vec1, vec2) / (norm_a * norm_b))


def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
    """Score each candidate against the current combination.

    Args:
        query_vector: Vector each candidate is compared to (the caller passes
            the average embedding of the ingredients chosen so far).
        embedding_list: ``(name, embedding)`` pairs of the candidates.
        all_good_embeddings: ``(name, embedding)`` pairs of the ingredients
            already chosen.
        avg_weight: Weight of the similarity to ``query_vector``; the mean
            similarity to the individual chosen ingredients gets
            ``1 - avg_weight``.

    Returns:
        ``(name, embedding, combined_score)`` tuples sorted by score,
        best first.
    """
    results = []
    for name, emb in embedding_list:
        avg_similarity = get_cosine_similarity(query_vector, emb)

        individual = [get_cosine_similarity(good_emb, emb) for _, good_emb in all_good_embeddings]
        # With nothing chosen yet there is no individual term to average.
        avg_individual_similarity = sum(individual) / len(individual) if individual else 0.0

        combined_score = avg_weight * avg_similarity + (1 - avg_weight) * avg_individual_similarity
        results.append((name, emb, combined_score))

    results.sort(key=lambda item: item[2], reverse=True)
    return results


def _clean_unique(ingredients):
    """Strip entries, drop empty ones and remove duplicates, keeping order."""
    cleaned = (ing.strip() for ing in ingredients)
    return list(dict.fromkeys(ing for ing in cleaned if ing))


def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
    """Pick the ingredient combination to cook with, using RecipeBERT embeddings.

    All required ingredients are kept (truncated to ``max_ingredients``).
    The remaining slots are filled greedily: in each round the available
    ingredient with the highest combined score against the current selection
    is added.

    Args:
        required_ingredients: Ingredient names that must be used.
        available_ingredients: Optional ingredient names to choose from.
        max_ingredients: Upper bound on the size of the result; coerced to
            ``int`` because UI sliders may deliver a float.
        avg_weight: Passed through to ``get_combined_scores``.

    Returns:
        List of ingredient names, required ones first in input order.
    """
    max_ingredients = int(max_ingredients)

    # Order-preserving de-duplication keeps the result deterministic.
    required_ingredients = _clean_unique(required_ingredients)
    required_set = set(required_ingredients)
    available_ingredients = [
        ing for ing in _clean_unique(available_ingredients) if ing not in required_set
    ]

    # Special case: with no required ingredients, seed the selection with a
    # random available one.
    if not required_ingredients and available_ingredients:
        random_ingredient = random.choice(available_ingredients)
        required_ingredients = [random_ingredient]
        available_ingredients = [i for i in available_ingredients if i != random_ingredient]

    # Nothing to work with, or already at capacity
    if not required_ingredients or len(required_ingredients) >= max_ingredients:
        return required_ingredients[:max_ingredients]

    # Nothing to add
    if not available_ingredients:
        return required_ingredients

    embed_required = [(name, get_embedding(name)) for name in required_ingredients]
    embed_available = [(name, get_embedding(name)) for name in available_ingredients]

    num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients))

    final_ingredients = list(embed_required)
    for _ in range(num_to_add):
        avg = average_embedding(final_ingredients)
        candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight)
        if not candidates:
            break

        best_name, best_embedding, _ = candidates[0]
        final_ingredients.append((best_name, best_embedding))
        embed_available = [item for item in embed_available if item[0] != best_name]

    return [name for name, _ in final_ingredients]


def skip_special_tokens(text, special_tokens):
    """Return ``text`` with every token in ``special_tokens`` removed."""
    for token in special_tokens:
        text = text.replace(token, "")
    return text


def target_postprocessing(texts, special_tokens):
    """Clean generated text(s): drop special tokens, then apply ``tokens_map``.

    A single string is accepted and treated as a one-element list; a list is
    always returned.
    """
    if not isinstance(texts, list):
        texts = [texts]

    new_texts = []
    for text in texts:
        text = skip_special_tokens(text, special_tokens)
        for k, v in tokens_map.items():
            text = text.replace(k, v)
        new_texts.append(text)
    return new_texts


def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
    """Check that the recipe lists roughly as many ingredients as expected.

    Only the *number* of non-blank recipe ingredients is compared with the
    number of expected ingredients; they may differ by at most ``tolerance``.
    """
    recipe_count = len([ing for ing in recipe_ingredients if ing and ing.strip()])
    expected_count = len(expected_ingredients)
    return abs(recipe_count - expected_count) <= tolerance


def _split_items(text):
    """Split a "--"-delimited section body into capitalized, non-empty items."""
    return [item.strip().capitalize() for item in text.split("--") if item.strip()]


def _parse_generated_recipe(generated_text):
    """Parse post-processed model output into a dict of the sections found.

    Each line is expected to start with ``title:``, ``ingredients:`` or
    ``directions:``; other lines are ignored. Only the leading label is
    removed, so the same word appearing later in the line is preserved.
    """
    recipe = {}
    for section in generated_text.split("\n"):
        section = section.strip()
        if section.startswith("title:"):
            recipe["title"] = section[len("title:"):].strip().capitalize()
        elif section.startswith("ingredients:"):
            recipe["ingredients"] = _split_items(section[len("ingredients:"):])
        elif section.startswith("directions:"):
            recipe["directions"] = _split_items(section[len("directions:"):])
    return recipe


def _fallback_recipe(ingredients):
    """Build the placeholder recipe returned when generation fails."""
    return {
        "title": f"Recipe with {ingredients[0] if ingredients else 'ingredients'}",
        "ingredients": ingredients,
        "directions": ["Error generating recipe instructions"],
    }


def generate_recipe_with_t5(ingredients_list, max_retries=5):
    """Generate a recipe with the T5 model, retrying until it validates.

    Generation is sampled, so each attempt can differ. From the second
    attempt on the ingredient order is shuffled. A recipe is accepted when
    its ingredient count is within 1 of the requested count; the last
    attempt's recipe is returned even if it does not validate. If the final
    attempt raises, a placeholder recipe is returned instead.

    Args:
        ingredients_list: Ingredient names to build the recipe from.
        max_retries: Maximum number of generation attempts.

    Returns:
        Dict with ``title`` (str), ``ingredients`` (list of str) and
        ``directions`` (list of str).
    """
    original_ingredients = ingredients_list.copy()

    # Generation settings
    generation_kwargs = {
        "max_length": 512,
        "min_length": 64,
        "do_sample": True,
        "top_k": 60,
        "top_p": 0.95,
    }
    prefix = "items: "

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            current_ingredients = original_ingredients.copy()
            if attempt > 0:
                random.shuffle(current_ingredients)

            ingredients_string = ", ".join(current_ingredients)

            inputs = t5_tokenizer(
                prefix + ingredients_string,
                max_length=256,
                padding="max_length",
                truncation=True,
                return_tensors="jax",
            )

            output_ids = t5_model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                **generation_kwargs,
            )

            # Decode with special tokens kept; target_postprocessing strips
            # them and rewrites the separators.
            generated_text = target_postprocessing(
                t5_tokenizer.batch_decode(output_ids.sequences, skip_special_tokens=False),
                special_tokens,
            )[0]

            recipe = _parse_generated_recipe(generated_text)

            # Fill in any section the model did not produce
            if "title" not in recipe:
                recipe["title"] = f"Recipe with {', '.join(current_ingredients[:3])}"
            if "ingredients" not in recipe:
                recipe["ingredients"] = current_ingredients
            if "directions" not in recipe:
                recipe["directions"] = ["No directions generated"]

            is_valid = validate_recipe_ingredients(
                recipe["ingredients"], original_ingredients, tolerance=1
            )
            if is_valid or is_last_attempt:
                return recipe
        except Exception as e:
            # Best effort: record the failure and retry; after the final
            # attempt the loop ends and the placeholder below is returned.
            logger.warning(
                "Recipe generation attempt %d/%d failed: %s", attempt + 1, max_retries, e
            )

    # Fallback
    return _fallback_recipe(original_ingredients)


def _parse_ingredient_text(text):
    """Split a comma-separated ingredient string into stripped, non-empty names."""
    if not text:
        return []
    return [ing.strip() for ing in text.split(",") if ing.strip()]


def generate_recipe_interface(required_ingredients_text, available_ingredients_text, max_ingredients):
    """Gradio handler: build a recipe and return four Markdown strings.

    Returns:
        ``(title, ingredients, directions, used_ingredients)``. On any error
        the first element carries the error message and the rest are empty.
    """
    try:
        required_ingredients = _parse_ingredient_text(required_ingredients_text)
        available_ingredients = _parse_ingredient_text(available_ingredients_text)

        if not required_ingredients and not available_ingredients:
            return "❌ **Error:** Please provide at least some ingredients!", "", "", ""

        optimized_ingredients = find_best_ingredients(
            required_ingredients, available_ingredients, max_ingredients
        )
        recipe = generate_recipe_with_t5(optimized_ingredients)

        title = f"🍽️ **{recipe['title']}**"
        ingredients_formatted = "## 📋 Ingredients:\n" + "\n".join(
            [f"• {ing}" for ing in recipe['ingredients']]
        )
        directions_formatted = "## 👨‍🍳 Instructions:\n" + "\n".join(
            [f"{i + 1}. {step}" for i, step in enumerate(recipe['directions'])]
        )
        used_ingredients = "## ✅ Used Ingredients:\n" + ", ".join(optimized_ingredients)

        return title, ingredients_formatted, directions_formatted, used_ingredients
    except Exception as e:
        # Top-level UI boundary: surface the error to the user.
        return f"❌ **Error:** {str(e)}", "", "", ""


def generate_recipe_api(required_ingredients_text, available_ingredients_text, max_ingredients):
    """Gradio handler: build a recipe and return it as a JSON string.

    The JSON object has the keys ``title``, ``ingredients``, ``directions``
    and ``used_ingredients``; on failure it has a single ``error`` key.
    """
    try:
        required_ingredients = _parse_ingredient_text(required_ingredients_text)
        available_ingredients = _parse_ingredient_text(available_ingredients_text)

        if not required_ingredients and not available_ingredients:
            return json.dumps({"error": "No ingredients provided"}, indent=2)

        optimized_ingredients = find_best_ingredients(
            required_ingredients, available_ingredients, max_ingredients
        )
        recipe = generate_recipe_with_t5(optimized_ingredients)

        api_response = {
            'title': recipe['title'],
            'ingredients': recipe['ingredients'],
            'directions': recipe['directions'],
            'used_ingredients': optimized_ingredients,
        }
        return json.dumps(api_response, indent=2, ensure_ascii=False)
    except Exception as e:
        # Top-level API boundary: report the error in the response body.
        return json.dumps({"error": f"Error in recipe generation: {str(e)}"}, indent=2)


# Create Gradio interface
with gr.Blocks(title="🍳 AI Recipe Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🍳 AI Recipe Generator

    Generate delicious recipes using AI! This tool uses **RecipeBERT** to find the best ingredient combinations and **T5** to generate complete recipes.

    ## How to use:
    1. **Required Ingredients:** Enter ingredients you must use (comma-separated)
    2. **Available Ingredients:** Enter additional ingredients you have available (comma-separated)
    3. **Max Ingredients:** Set the maximum number of ingredients for your recipe
    4. Click **Generate Recipe** to create your personalized recipe!
    """)

    with gr.Tab("🍽️ Recipe Generator"):
        with gr.Row():
            with gr.Column():
                required_ingredients = gr.Textbox(
                    label="🎯 Required Ingredients",
                    placeholder="chicken, rice, onions",
                    info="Ingredients that must be included in the recipe (comma-separated)"
                )
                available_ingredients = gr.Textbox(
                    label="🥕 Available Ingredients",
                    placeholder="garlic, tomatoes, basil, cheese",
                    info="Additional ingredients you have available (comma-separated)"
                )
                max_ingredients = gr.Slider(
                    minimum=3,
                    maximum=12,
                    value=7,
                    step=1,
                    label="📊 Maximum Ingredients",
                    info="Maximum number of ingredients to use in the recipe"
                )
                generate_btn = gr.Button("🚀 Generate Recipe", variant="primary", size="lg")

            with gr.Column():
                recipe_title = gr.Markdown()
                used_ingredients = gr.Markdown()

        with gr.Row():
            with gr.Column():
                recipe_ingredients = gr.Markdown()
            with gr.Column():
                recipe_directions = gr.Markdown()

    with gr.Tab("🔌 API Format"):
        gr.Markdown("""
        ## API Response Format

        This tab shows the response in JSON format, compatible with your Flutter app.
        """)

        with gr.Row():
            with gr.Column():
                api_required = gr.Textbox(
                    label="Required Ingredients",
                    placeholder="chicken, rice, onions"
                )
                api_available = gr.Textbox(
                    label="Available Ingredients",
                    placeholder="garlic, tomatoes, basil"
                )
                api_max = gr.Slider(
                    minimum=3,
                    maximum=12,
                    value=7,
                    step=1,
                    label="Max Ingredients"
                )
                api_generate_btn = gr.Button("Generate JSON", variant="secondary")

            with gr.Column():
                api_output = gr.Code(language="json", label="API Response")

    # Event handlers - IMPORTANT: these create the API endpoints!
    generate_btn.click(
        fn=generate_recipe_interface,
        inputs=[required_ingredients, available_ingredients, max_ingredients],
        outputs=[recipe_title, recipe_ingredients, recipe_directions, used_ingredients]
    )

    api_generate_btn.click(
        fn=generate_recipe_api,
        inputs=[api_required, api_available, api_max],
        outputs=[api_output]
    )

    # Example inputs
    gr.Examples(
        examples=[
            ["chicken, rice", "onions, garlic, tomatoes, basil", 6],
            ["eggs, flour", "milk, sugar, vanilla, butter", 7],
            ["salmon", "lemon, dill, potatoes, asparagus", 5],
            ["", "beef, potatoes, carrots, onions, garlic", 6]
        ],
        inputs=[required_ingredients, available_ingredients, max_ingredients]
    )

# Launch the application
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )