import gradio as gr from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel import torch import numpy as np import random import json # Beibehalten, da es in flutter_api_generate_recipe verwendet wird # Lade RecipeBERT Modell (für semantische Zutat-Kombination) bert_model_name = "alexdseo/RecipeBERT" bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name) bert_model = AutoModel.from_pretrained(bert_model_name) bert_model.eval() # Setze das Modell in den Evaluationsmodus # Lade T5 Rezeptgenerierungsmodell MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation" t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True) t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH) # Token Mapping für die T5 Modell-Ausgabe special_tokens = t5_tokenizer.all_special_tokens tokens_map = { "": "--", "
": "\n" } def get_embedding(text): """Berechnet das Embedding für einen Text mit Mean Pooling über alle Tokens""" inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True) with torch.no_grad(): outputs = bert_model(**inputs) # Mean Pooling - Mittelwert aller Token-Embeddings attention_mask = inputs['attention_mask'] token_embeddings = outputs.last_hidden_state input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) return (sum_embeddings / sum_mask).squeeze(0) def average_embedding(embedding_list): """Berechnet den Durchschnitt einer Liste von Embeddings""" tensors = torch.stack([emb for _, emb in embedding_list]) return tensors.mean(dim=0) def get_cosine_similarity(vec1, vec2): """Berechnet die Cosinus-Ähnlichkeit zwischen zwei Vektoren""" if torch.is_tensor(vec1): vec1 = vec1.detach().numpy() if torch.is_tensor(vec2): vec2 = vec2.detach().numpy() # Stelle sicher, dass die Vektoren die richtige Form haben (flachen sie bei Bedarf ab) vec1 = vec1.flatten() vec2 = vec2.flatten() dot_product = np.dot(vec1, vec2) norm_a = np.linalg.norm(vec1) norm_b = np.linalg.norm(vec2) # Division durch Null vermeiden if norm_a == 0 or norm_b == 0: return 0 return dot_product / (norm_a * norm_b) def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6): """Berechnet einen kombinierten Score unter Berücksichtigung der Ähnlichkeit zum Durchschnitt und zu einzelnen Zutaten""" results = [] for name, emb in embedding_list: # Ähnlichkeit zum Durchschnittsvektor avg_similarity = get_cosine_similarity(query_vector, emb) # Durchschnittliche Ähnlichkeit zu einzelnen Zutaten individual_similarities = [get_cosine_similarity(good_emb, emb) for _, good_emb in all_good_embeddings] avg_individual_similarity = sum(individual_similarities) / len(individual_similarities) # Kombinierter Score (gewichteter Durchschnitt) combined_score = avg_weight * avg_similarity + (1 - avg_weight) * avg_individual_similarity results.append((name, emb, combined_score)) # Sortiere nach kombiniertem Score (absteigend) results.sort(key=lambda x: x[2], reverse=True) return results def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6): """ Findet die besten Zutaten basierend auf RecipeBERT Embeddings. """ # Stelle sicher, dass keine Duplikate in den Listen sind required_ingredients = list(set(required_ingredients)) available_ingredients = list(set([i for i in available_ingredients if i not in required_ingredients])) # Sonderfall: Wenn keine benötigten Zutaten vorhanden sind, wähle zufällig eine aus den verfügbaren Zutaten if not required_ingredients and available_ingredients: random_ingredient = random.choice(available_ingredients) required_ingredients = [random_ingredient] available_ingredients = [i for i in available_ingredients if i != random_ingredient] # print(f"Keine benötigten Zutaten angegeben. Zufällig ausgewählt: {random_ingredient}") # Wenn immer noch keine Zutaten vorhanden oder bereits maximale Kapazität erreicht ist if not required_ingredients or len(required_ingredients) >= max_ingredients: return required_ingredients[:max_ingredients] # Wenn keine zusätzlichen Zutaten verfügbar sind if not available_ingredients: return required_ingredients # Berechne Embeddings für alle Zutaten embed_required = [(e, get_embedding(e)) for e in required_ingredients] embed_available = [(e, get_embedding(e)) for e in available_ingredients] # Anzahl der hinzuzufügenden Zutaten num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients)) # Kopiere benötigte Zutaten in die endgültige Liste final_ingredients = embed_required.copy() # Füge die besten Zutaten hinzu for _ in range(num_to_add): # Berechne den Durchschnittsvektor der aktuellen Kombination avg = average_embedding(final_ingredients) # Berechne kombinierte Scores für alle Kandidaten candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight) # Wenn keine Kandidaten mehr übrig sind, breche ab if not candidates: break # Wähle die beste Zutat best_name, best_embedding, _ = candidates[0] # Füge die beste Zutat zur endgültigen Liste hinzu final_ingredients.append((best_name, best_embedding)) # Entferne die Zutat aus den verfügbaren Zutaten embed_available = [item for item in embed_available if item[0] != best_name] # Extrahiere nur die Zutatennamen return [name for name, _ in final_ingredients] def skip_special_tokens(text, special_tokens): """Entfernt spezielle Tokens aus dem Text""" for token in special_tokens: text = text.replace(token, "") return text def target_postprocessing(texts, special_tokens): """Post-processed generierten Text""" if not isinstance(texts, list): texts = [texts] new_texts = [] for text in texts: text = skip_special_tokens(text, special_tokens) for k, v in tokens_map.items(): text = text.replace(k, v) new_texts.append(text) return new_texts def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0): """ Validiert, ob das Rezept ungefähr die erwarteten Zutaten enthält. """ recipe_count = len([ing for ing in recipe_ingredients if ing and ing.strip()]) expected_count = len(expected_ingredients) return abs(recipe_count - expected_count) == tolerance def generate_recipe_with_t5(ingredients_list, max_retries=5): """Generiert ein Rezept mit dem T5 Rezeptgenerierungsmodell mit Validierung.""" original_ingredients = ingredients_list.copy() for attempt in range(max_retries): try: # Für Wiederholungsversuche nach dem ersten Versuch, mische die Zutaten if attempt > 0: current_ingredients = original_ingredients.copy() random.shuffle(current_ingredients) else: current_ingredients = ingredients_list # Formatiere Zutaten als kommaseparierten String ingredients_string = ", ".join(current_ingredients) prefix = "items: " # Generationseinstellungen generation_kwargs = { "max_length": 512, "min_length": 64, "do_sample": True, "top_k": 60, "top_p": 0.95 } # print(f"Versuch {attempt + 1}: {prefix + ingredients_string}") # Tokenisiere Eingabe inputs = t5_tokenizer( prefix + ingredients_string, max_length=256, padding="max_length", truncation=True, return_tensors="jax" ) # Generiere Text output_ids = t5_model.generate( input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, **generation_kwargs ) # Dekodieren und Nachbearbeiten generated = output_ids.sequences generated_text = target_postprocessing( t5_tokenizer.batch_decode(generated, skip_special_tokens=False), special_tokens )[0] # Abschnitte parsen recipe = {} sections = generated_text.split("\n") for section in sections: section = section.strip() if section.startswith("title:"): recipe["title"] = section.replace("title:", "").strip().capitalize() elif section.startswith("ingredients:"): ingredients_text = section.replace("ingredients:", "").strip() recipe["ingredients"] = [item.strip().capitalize() for item in ingredients_text.split("--") if item.strip()] elif section.startswith("directions:"): directions_text = section.replace("directions:", "").strip() recipe["directions"] = [step.strip().capitalize() for step in directions_text.split("--") if step.strip()] # Wenn der Titel fehlt, erstelle einen if "title" not in recipe: recipe["title"] = f"Rezept mit {', '.join(current_ingredients[:3])}" # Stelle sicher, dass alle Abschnitte existieren if "ingredients" not in recipe: recipe["ingredients"] = current_ingredients if "directions" not in recipe: recipe["directions"] = ["Keine Anweisungen generiert"] # Validiere das Rezept if validate_recipe_ingredients(recipe["ingredients"], original_ingredients): # print(f"Erfolg bei Versuch {attempt + 1}: Rezept hat die richtige Anzahl von Zutaten") return recipe else: # print(f"Versuch {attempt + 1} fehlgeschlagen: Erwartet {len(original_ingredients)} Zutaten, erhalten {len(recipe['ingredients'])}") if attempt == max_retries - 1: # print("Maximale Wiederholungsversuche erreicht, letztes generiertes Rezept wird zurückgegeben") return recipe except Exception as e: # print(f"Fehler bei der Rezeptgenerierung Versuch {attempt + 1}: {str(e)}") if attempt == max_retries - 1: return { "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}", "ingredients": original_ingredients, "directions": ["Fehler beim Generieren der Rezeptanweisungen"] } # Fallback (sollte nicht erreicht werden) return { "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}", "ingredients": original_ingredients, "directions": ["Fehler beim Generieren der Rezeptanweisungen"] } # Diese Funktion wird von der Gradio-UI und der FastAPI-Route aufgerufen. # Sie ist für die Kernlogik zuständig. def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries): """ Kernlogik zur Verarbeitung einer Rezeptgenerierungsanfrage. Ausgelagert, um von verschiedenen Endpunkten aufgerufen zu werden. """ if not required_ingredients and not available_ingredients: return {"error": "Keine Zutaten angegeben"} try: # Optimale Zutaten finden optimized_ingredients = find_best_ingredients( required_ingredients, available_ingredients, max_ingredients ) # Rezept mit optimierten Zutaten generieren recipe = generate_recipe_with_t5(optimized_ingredients, max_retries) # Ergebnis formatieren result = { 'title': recipe['title'], 'ingredients': recipe['ingredients'], 'directions': recipe['directions'], 'used_ingredients': optimized_ingredients } return result except Exception as e: return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"} def flutter_api_generate_recipe(ingredients_data: str): # Typ-Hint für Klarheit """ Diese Funktion wird vom 'hugging_face_chat_gradio'-Paket über die API aufgerufen. Sie erwartet einen JSON-STRING als Eingabe. """ try: # Der 'hugging_face_chat_gradio'-Client sendet das Payload als String. data = json.loads(ingredients_data) required_ingredients = data.get('required_ingredients', []) available_ingredients = data.get('available_ingredients', []) max_ingredients = data.get('max_ingredients', 7) max_retries = data.get('max_retries', 5) # Rufe die Kernlogik auf result_dict = process_recipe_request_logic( required_ingredients, available_ingredients, max_ingredients, max_retries ) return json.dumps(result_dict) # Gibt einen JSON-STRING zurück except Exception as e: # Logge den Fehler für Debugging im Space-Log print(f"Error in flutter_api_generate_recipe: {str(e)}") return json.dumps({"error": f"Internal API Error: {str(e)}"}) def gradio_ui_generate_recipe(required_ingredients_text, available_ingredients_text, max_ingredients_val, max_retries_val): """Gradio UI Funktion für die Web-Oberfläche""" try: required_ingredients = [ing.strip() for ing in required_ingredients_text.split(',') if ing.strip()] available_ingredients = [ing.strip() for ing in available_ingredients_text.split(',') if ing.strip()] # Rufe die Kernlogik auf result = process_recipe_request_logic( required_ingredients, available_ingredients, max_ingredients_val, max_retries_val ) if 'error' in result: return result['error'], "", "", "" ingredients_list = '\n'.join([f"• {ing}" for ing in result['ingredients']]) directions_list = '\n'.join([f"{i+1}. {dir}" for i, dir in enumerate(result['directions'])]) used_ingredients = ', '.join(result['used_ingredients']) return ( result['title'], ingredients_list, directions_list, used_ingredients ) except Exception as e: # Fehlermeldung für die Gradio UI return f"Fehler: {str(e)}", "", "", "" # Erstelle die Gradio Oberfläche with gr.Blocks(title="AI Rezept Generator") as demo: gr.Markdown("# 🍳 AI Rezept Generator") gr.Markdown("Generiere Rezepte mit KI und intelligenter Zutat-Kombination!") with gr.Tab("Web-Oberfläche"): with gr.Row(): with gr.Column(): required_ing = gr.Textbox( label="Benötigte Zutaten (kommasepariert)", placeholder="Hähnchen, Reis, Zwiebel", lines=2 ) available_ing = gr.Textbox( label="Verfügbare Zutaten (kommasepariert, optional)", placeholder="Knoblauch, Tomate, Pfeffer, Kräuter", lines=2 ) max_ing = gr.Slider(3, 10, value=7, step=1, label="Maximale Zutaten") max_retries = gr.Slider(1, 10, value=5, step=1, label="Max. Wiederholungsversuche") generate_btn = gr.Button("Rezept generieren", variant="primary") with gr.Column(): title_output = gr.Textbox(label="Rezepttitel", interactive=False) ingredients_output = gr.Textbox(label="Zutaten", lines=8, interactive=False) directions_output = gr.Textbox(label="Anweisungen", lines=10, interactive=False) used_ingredients_output = gr.Textbox(label="Verwendete Zutaten", interactive=False) generate_btn.click( fn=gradio_ui_generate_recipe, inputs=[required_ing, available_ing, max_ing, max_retries], outputs=[title_output, ingredients_output, directions_output, used_ingredients_output] ) with gr.Tab("API-Test"): gr.Markdown("### Teste die Flutter API (via 'hugging_face_chat_gradio' Client)") gr.Markdown("Dieser Tab zeigt, wie die Eingabe für die 'generate_recipe_for_flutter'-API aussehen sollte.") api_input = gr.Textbox( label="JSON-Eingabe (für API-Aufruf)", placeholder='{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic"], "max_ingredients": 6}', lines=4 ) api_output = gr.Textbox(label="JSON-Ausgabe", lines=15, interactive=False) api_test_btn = gr.Button("API testen", variant="secondary") api_test_btn.click( fn=flutter_api_generate_recipe, inputs=[api_input], outputs=[api_output], api_name="generate_recipe_for_flutter" # Dies ist der api_name, den das Flutter-Paket verwendet ) gr.Examples( examples=[ ['{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic", "tomato"], "max_ingredients": 6}'], ['{"ingredients": ["pasta"], "available_ingredients": ["cheese", "mushrooms", "cream"], "max_ingredients": 5}'] ], inputs=[api_input] ) # Gradio-App starten if __name__ == "__main__": demo.launch()