| import gradio as gr |
| from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel |
| import torch |
| import numpy as np |
| import random |
| import json |
|
|
| |
# RecipeBERT encoder: used by get_embedding() to embed ingredient names.
bert_model_name = "alexdseo/RecipeBERT"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
# nn.Module.eval() returns the module itself, so switching to inference mode
# can be chained directly onto the load.
bert_model = AutoModel.from_pretrained(bert_model_name).eval()

# T5 recipe generator (Flax weights) and its fast tokenizer.
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)

# Post-processing tables: tokens to strip from generated text, and the
# model's structural markers mapped to plain-text separators.
special_tokens = t5_tokenizer.all_special_tokens
tokens_map = {"<sep>": "--", "<section>": "\n"}
|
|
def get_embedding(text):
    """Return the RecipeBERT embedding of ``text`` via attention-masked mean pooling.

    The token vectors of the last hidden layer are averaged, counting only
    positions whose attention mask is 1 (padding is ignored). The batch
    dimension is squeezed away, so a single string yields a 1-D tensor.
    """
    encoded = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        hidden = bert_model(**encoded).last_hidden_state

    # Broadcast the (batch, seq) mask over the hidden dimension.
    mask = encoded['attention_mask'].unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * mask).sum(1)
    # Clamp guards against division by zero for an all-padding row.
    counts = mask.sum(1).clamp(min=1e-9)
    return (summed / counts).squeeze(0)
|
|
def average_embedding(embedding_list):
    """Return the element-wise mean of the vectors in ``embedding_list``.

    Args:
        embedding_list: Sequence of ``(name, tensor)`` pairs; names are ignored.
            All tensors must have the same shape (required by ``torch.stack``).

    Returns:
        A tensor with the shape of a single input vector.
    """
    vectors = [vector for _label, vector in embedding_list]
    stacked = torch.stack(vectors)
    return stacked.mean(dim=0)
|
|
def get_cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two vectors as a Python float.

    Accepts torch tensors (any device, with or without grad), NumPy arrays or
    plain sequences. Both inputs are flattened before comparison.

    Returns:
        ``dot(a, b) / (|a| * |b|)``, or ``0.0`` when either vector has zero
        norm (the similarity is undefined there).
    """
    def _to_flat_array(vec):
        # Flatten any supported input into a 1-D NumPy array.
        if torch.is_tensor(vec):
            # .cpu() so tensors on an accelerator can be converted as well.
            vec = vec.detach().cpu().numpy()
        return np.asarray(vec).ravel()

    a = _to_flat_array(vec1)
    b = _to_flat_array(vec2)

    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    if norm_a == 0 or norm_b == 0:
        return 0.0

    # float() gives one consistent return type for both branches.
    return float(np.dot(a, b) / (norm_a * norm_b))
|
|
def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
    """Score and rank candidate ingredients against the current selection.

    For each candidate the score is
    ``avg_weight * sim(query_vector, candidate)
    + (1 - avg_weight) * mean(sim(good, candidate) for each selected ingredient)``,
    where ``sim`` is ``get_cosine_similarity``.

    Args:
        query_vector: Reference vector, typically the mean of the selection.
        embedding_list: Candidate ``(name, embedding)`` pairs.
        all_good_embeddings: Already selected ``(name, embedding)`` pairs; must be
            non-empty when ``embedding_list`` is non-empty (its mean is taken).
        avg_weight: Weight of the similarity to ``query_vector``.

    Returns:
        List of ``(name, embedding, score)`` sorted by score, highest first
        (stable for ties).
    """
    scored = []
    for name, emb in embedding_list:
        centroid_sim = get_cosine_similarity(query_vector, emb)
        pairwise = [get_cosine_similarity(good, emb) for _label, good in all_good_embeddings]
        mean_pairwise = sum(pairwise) / len(pairwise)
        score = avg_weight * centroid_sim + (1 - avg_weight) * mean_pairwise
        scored.append((name, emb, score))

    return sorted(scored, key=lambda entry: entry[2], reverse=True)
|
|
def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
    """Select up to ``max_ingredients`` ingredients that fit well together.

    Uses RecipeBERT embeddings. All required ingredients are kept (truncated to
    ``max_ingredients`` if there are too many). Remaining slots are filled
    greedily from the available ingredients: each round adds the candidate
    with the highest combined score (see ``get_combined_scores``) against the
    current selection. With no required ingredients, a random available one
    seeds the selection.

    Args:
        required_ingredients: Ingredients that must be used. Duplicates are
            removed, and first-occurrence order is kept.
        available_ingredients: Optional candidates. Duplicates and entries that
            are already required are dropped.
        max_ingredients: Upper bound on the result length. Coerced to int so
            float inputs such as 7.0 work with slicing and ``range``.
        avg_weight: Passed through to ``get_combined_scores``.

    Returns:
        List of ingredient names: required ones first, then the added ones in
        the order they were chosen.
    """
    max_ingredients = int(max_ingredients)

    # dict.fromkeys de-duplicates while keeping first-occurrence order, so the
    # result is deterministic (set() order depends on string hash randomization).
    required = list(dict.fromkeys(required_ingredients))
    required_set = set(required)
    available = [i for i in dict.fromkeys(available_ingredients) if i not in required_set]

    # No required ingredient given: seed the selection with a random one.
    if not required and available:
        seed = random.choice(available)
        required = [seed]
        available = [i for i in available if i != seed]

    if not required or len(required) >= max_ingredients:
        return required[:max_ingredients]

    if not available:
        return required

    embed_required = [(name, get_embedding(name)) for name in required]
    embed_available = [(name, get_embedding(name)) for name in available]

    num_to_add = min(max_ingredients - len(required), len(embed_available))
    final_ingredients = list(embed_required)

    for _ in range(num_to_add):
        # Re-rank the remaining candidates against the current selection.
        centroid = average_embedding(final_ingredients)
        candidates = get_combined_scores(centroid, embed_available, final_ingredients, avg_weight)
        if not candidates:
            break

        best_name, best_embedding, _score = candidates[0]
        final_ingredients.append((best_name, best_embedding))
        embed_available = [item for item in embed_available if item[0] != best_name]

    return [name for name, _ in final_ingredients]
|
|
def skip_special_tokens(text, special_tokens):
    """Return ``text`` with every occurrence of each token in ``special_tokens`` removed.

    Tokens are removed one after another, in the given order.
    """
    cleaned = text
    for marker in special_tokens:
        cleaned = cleaned.replace(marker, "")
    return cleaned
|
|
def target_postprocessing(texts, special_tokens):
    """Clean decoded model output.

    For each text: remove ``special_tokens``, then replace each key of the
    module-level ``tokens_map`` with its value.

    Args:
        texts: A list of strings, or a single string (wrapped into a list).
        special_tokens: Tokens to strip.

    Returns:
        A list of cleaned strings (always a list).
    """
    batch = texts if isinstance(texts, list) else [texts]

    def _clean(raw):
        out = skip_special_tokens(raw, special_tokens)
        for marker, replacement in tokens_map.items():
            out = out.replace(marker, replacement)
        return out

    return [_clean(item) for item in batch]
|
|
def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
    """Check that a recipe lists roughly the expected number of ingredients.

    Empty or whitespace-only entries in ``recipe_ingredients`` are not counted.

    Args:
        recipe_ingredients: Ingredient strings parsed from the generated recipe.
        expected_ingredients: Ingredients the recipe was requested for.
        tolerance: Maximum allowed absolute difference between the two counts.

    Returns:
        True if the counts differ by at most ``tolerance``.
    """
    recipe_count = sum(1 for ing in recipe_ingredients if ing and ing.strip())
    expected_count = len(expected_ingredients)
    # `<=`, not `==`: with tolerance > 0 an exact match must still pass.
    return abs(recipe_count - expected_count) <= tolerance
|
|
def _parse_generated_recipe(generated_text):
    """Split post-processed T5 output into ``title`` / ``ingredients`` / ``directions``.

    Sections are newline-separated and start with ``title:``, ``ingredients:``
    or ``directions:``. List items inside a section are separated by ``--``.
    Sections that do not appear are absent from the returned dict.
    """
    recipe = {}
    for raw_section in generated_text.split("\n"):
        section = raw_section.strip()
        # Slice off only the leading prefix (str.replace would also remove it
        # anywhere else in the line).
        if section.startswith("title:"):
            recipe["title"] = section[len("title:"):].strip().capitalize()
        elif section.startswith("ingredients:"):
            body = section[len("ingredients:"):]
            recipe["ingredients"] = [item.strip().capitalize() for item in body.split("--") if item.strip()]
        elif section.startswith("directions:"):
            body = section[len("directions:"):]
            recipe["directions"] = [step.strip().capitalize() for step in body.split("--") if step.strip()]
    return recipe


def _fallback_recipe(ingredients):
    """Placeholder recipe returned when generation fails."""
    return {
        "title": f"Rezept mit {ingredients[0] if ingredients else 'Zutaten'}",
        "ingredients": ingredients,
        "directions": ["Fehler beim Generieren der Rezeptanweisungen"],
    }


def generate_recipe_with_t5(ingredients_list, max_retries=5):
    """Generate a recipe for ``ingredients_list`` with the T5 recipe model.

    Up to ``max_retries`` attempts are made. The first uses the ingredients in
    the given order, and later ones use a shuffled copy. An attempt is accepted
    when ``validate_recipe_ingredients`` approves the parsed ingredient list.
    The last attempt's recipe is returned even if it does not validate.

    Args:
        ingredients_list: Ingredient names to build the recipe from.
        max_retries: Number of generation attempts. Coerced to int so float
            values (e.g. from a UI slider) work with ``range``.

    Returns:
        Dict with keys ``title``, ``ingredients`` and ``directions``. If every
        attempt raises, or ``max_retries`` < 1, a placeholder recipe built from
        the input ingredients is returned.
    """
    import jax  # Local import: only needed to create PRNG keys for sampling.

    original_ingredients = list(ingredients_list)
    max_retries = int(max_retries)
    prefix = "items: "
    generation_kwargs = {
        "max_length": 512,
        "min_length": 64,
        "do_sample": True,
        "top_k": 60,
        "top_p": 0.95,
    }

    for attempt in range(max_retries):
        try:
            if attempt > 0:
                # Retries give the model a different ingredient order.
                current_ingredients = original_ingredients.copy()
                random.shuffle(current_ingredients)
            else:
                current_ingredients = original_ingredients

            inputs = t5_tokenizer(
                prefix + ", ".join(current_ingredients),
                max_length=256,
                padding="max_length",
                truncation=True,
                return_tensors="jax",
            )

            # Flax `generate` falls back to jax.random.PRNGKey(0) when no key
            # is given, which makes do_sample=True deterministic (same input ->
            # same recipe on every call). Pass a fresh key per attempt.
            output = t5_model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                prng_key=jax.random.PRNGKey(random.randrange(2 ** 31)),
                **generation_kwargs,
            )

            # Decode without skipping special tokens; target_postprocessing
            # strips them and maps the <sep>/<section> markers.
            generated_text = target_postprocessing(
                t5_tokenizer.batch_decode(output.sequences, skip_special_tokens=False),
                special_tokens,
            )[0]

            recipe = _parse_generated_recipe(generated_text)
            # Fill in any section the model did not produce.
            recipe.setdefault("title", f"Rezept mit {', '.join(current_ingredients[:3])}")
            recipe.setdefault("ingredients", current_ingredients)
            recipe.setdefault("directions", ["Keine Anweisungen generiert"])

            is_last_attempt = attempt == max_retries - 1
            if validate_recipe_ingredients(recipe["ingredients"], original_ingredients) or is_last_attempt:
                return recipe

        except Exception as e:
            # Best effort: log and retry; on the last attempt return a placeholder.
            print(f"Error in generate_recipe_with_t5 (attempt {attempt + 1}/{max_retries}): {e}")
            if attempt == max_retries - 1:
                return _fallback_recipe(original_ingredients)

    # Only reached when the loop never runs (max_retries < 1).
    return _fallback_recipe(original_ingredients)
|
|
| |
| |
def _normalize_ingredient_list(value):
    """Turn UI/API input into a list of stripped, non-empty ingredient strings.

    Accepts None (-> []), a comma-separated string, a single scalar, or any
    iterable of items. None items and blank entries are dropped.
    """
    if value is None:
        return []
    if isinstance(value, str):
        # A plain string would otherwise be iterated character by character.
        value = value.split(",")
    else:
        try:
            value = list(value)
        except TypeError:
            value = [value]

    cleaned = []
    for item in value:
        if item is None:
            continue
        text = str(item).strip()
        if text:
            cleaned.append(text)
    return cleaned


def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries):
    """Core logic for a recipe generation request.

    Factored out so that several endpoints (web UI, Flutter API) can call it.

    Args:
        required_ingredients: Ingredients that must be used (list, comma-separated
            string or None).
        available_ingredients: Optional additional ingredients (same formats).
        max_ingredients: Maximum number of ingredients; converted to int.
        max_retries: Number of generation attempts; converted to int.

    Returns:
        Dict with ``title``, ``ingredients``, ``directions`` and
        ``used_ingredients``, or ``{"error": "..."}`` on failure.
    """
    required_ingredients = _normalize_ingredient_list(required_ingredients)
    available_ingredients = _normalize_ingredient_list(available_ingredients)

    if not required_ingredients and not available_ingredients:
        return {"error": "Keine Zutaten angegeben"}

    try:
        max_ingredients = int(max_ingredients)
        max_retries = int(max_retries)

        optimized_ingredients = find_best_ingredients(
            required_ingredients,
            available_ingredients,
            max_ingredients
        )

        recipe = generate_recipe_with_t5(optimized_ingredients, max_retries)

        return {
            'title': recipe['title'],
            'ingredients': recipe['ingredients'],
            'directions': recipe['directions'],
            'used_ingredients': optimized_ingredients
        }

    except Exception as e:
        # Request boundary: report the failure to the caller instead of raising.
        return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"}
|
|
def flutter_api_generate_recipe(ingredients_data: str):
    """API endpoint for the Flutter client.

    Called via the API by the 'hugging_face_chat_gradio' package. Expects a JSON
    *string* encoding an object with the optional keys ``required_ingredients``,
    ``available_ingredients``, ``max_ingredients`` (default 7) and
    ``max_retries`` (default 5). A key with an explicit ``null`` value is
    treated as missing.

    Returns:
        A JSON string: the recipe dict from ``process_recipe_request_logic``, or
        ``{"error": "..."}`` for invalid input or internal failures.
    """
    try:
        data = json.loads(ingredients_data)
    except (ValueError, TypeError) as e:
        # ValueError covers json.JSONDecodeError (empty or malformed input);
        # TypeError covers non-string input such as None.
        return json.dumps({"error": f"Invalid JSON input: {e}"})

    if not isinstance(data, dict):
        return json.dumps({"error": "Invalid JSON input: expected a JSON object"})

    def _field(key, default):
        # Map both a missing key and an explicit null to the default.
        value = data.get(key)
        return default if value is None else value

    try:
        result_dict = process_recipe_request_logic(
            _field('required_ingredients', []),
            _field('available_ingredients', []),
            _field('max_ingredients', 7),
            _field('max_retries', 5),
        )
        return json.dumps(result_dict)

    except Exception as e:
        # API boundary: log and return an error payload instead of raising.
        print(f"Error in flutter_api_generate_recipe: {str(e)}")
        return json.dumps({"error": f"Internal API Error: {str(e)}"})
|
|
def gradio_ui_generate_recipe(required_ingredients_text, available_ingredients_text, max_ingredients_val, max_retries_val):
    """Handler for the web UI's generate button.

    Args:
        required_ingredients_text: Comma-separated required ingredients
            (None is treated as empty).
        available_ingredients_text: Comma-separated optional ingredients
            (None is treated as empty).
        max_ingredients_val: Slider value for the maximum ingredient count.
            Converted to int because a float (e.g. 7.0) would break list
            slicing and ``range()`` downstream.
        max_retries_val: Slider value for the number of attempts; converted to int.

    Returns:
        A 4-tuple of strings (title, bulleted ingredients, numbered directions,
        used ingredients) for the four output textboxes. On failure the error
        message is the first element and the others are empty.
    """
    def _split_csv(text):
        # Split a comma-separated textbox value into stripped, non-empty items.
        return [part.strip() for part in (text or "").split(',') if part.strip()]

    try:
        required_ingredients = _split_csv(required_ingredients_text)
        available_ingredients = _split_csv(available_ingredients_text)

        result = process_recipe_request_logic(
            required_ingredients,
            available_ingredients,
            int(max_ingredients_val),
            int(max_retries_val),
        )

        if 'error' in result:
            return result['error'], "", "", ""

        ingredients_list = '\n'.join(f"• {ing}" for ing in result['ingredients'])
        # `step`, not `dir`, to avoid shadowing the builtin.
        directions_list = '\n'.join(f"{i}. {step}" for i, step in enumerate(result['directions'], start=1))
        used_ingredients = ', '.join(result['used_ingredients'])

        return (
            result['title'],
            ingredients_list,
            directions_list,
            used_ingredients
        )

    except Exception as e:
        # UI boundary: show the error in the title field instead of raising.
        return f"Fehler: {str(e)}", "", "", ""
|
|
| |
with gr.Blocks(title="AI Rezept Generator") as demo:
    gr.Markdown("# 🍳 AI Rezept Generator")
    gr.Markdown("Generiere Rezepte mit KI und intelligenter Zutat-Kombination!")

    # Tab 1: interactive form wired to gradio_ui_generate_recipe.
    with gr.Tab("Web-Oberfläche"):
        with gr.Row():
            with gr.Column():
                required_ing = gr.Textbox(
                    label="Benötigte Zutaten (kommasepariert)",
                    placeholder="Hähnchen, Reis, Zwiebel",
                    lines=2
                )
                available_ing = gr.Textbox(
                    label="Verfügbare Zutaten (kommasepariert, optional)",
                    placeholder="Knoblauch, Tomate, Pfeffer, Kräuter",
                    lines=2
                )
                max_ing = gr.Slider(3, 10, value=7, step=1, label="Maximale Zutaten")
                max_retries = gr.Slider(1, 10, value=5, step=1, label="Max. Wiederholungsversuche")

                generate_btn = gr.Button("Rezept generieren", variant="primary")

            with gr.Column():
                title_output = gr.Textbox(label="Rezepttitel", interactive=False)
                ingredients_output = gr.Textbox(label="Zutaten", lines=8, interactive=False)
                directions_output = gr.Textbox(label="Anweisungen", lines=10, interactive=False)
                used_ingredients_output = gr.Textbox(label="Verwendete Zutaten", interactive=False)

        generate_btn.click(
            fn=gradio_ui_generate_recipe,
            inputs=[required_ing, available_ing, max_ing, max_retries],
            outputs=[title_output, ingredients_output, directions_output, used_ingredients_output]
        )

    # Tab 2: raw JSON in/out; the click handler is exposed under
    # api_name="generate_recipe_for_flutter".
    with gr.Tab("API-Test"):
        gr.Markdown("### Teste die Flutter API (via 'hugging_face_chat_gradio' Client)")
        gr.Markdown("Dieser Tab zeigt, wie die Eingabe für die 'generate_recipe_for_flutter'-API aussehen sollte.")

        api_input = gr.Textbox(
            label="JSON-Eingabe (für API-Aufruf)",
            placeholder='{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic"], "max_ingredients": 6}',
            lines=4
        )
        api_output = gr.Textbox(label="JSON-Ausgabe", lines=15, interactive=False)
        api_test_btn = gr.Button("API testen", variant="secondary")

        api_test_btn.click(
            fn=flutter_api_generate_recipe,
            inputs=[api_input],
            outputs=[api_output],
            api_name="generate_recipe_for_flutter"
        )

        # Both examples use the keys flutter_api_generate_recipe actually reads
        # ("required_ingredients", not "ingredients").
        gr.Examples(
            examples=[
                ['{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic", "tomato"], "max_ingredients": 6}'],
                ['{"required_ingredients": ["pasta"], "available_ingredients": ["cheese", "mushrooms", "cream"], "max_ingredients": 5}']
            ],
            inputs=[api_input]
        )


# Start the server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()
|
|