ReGe / app.py
TimInf's picture
Update app.py
88f38f2 verified
raw
history blame
19.5 kB
import gradio as gr
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
import torch
import numpy as np
import random
import json
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
# Lade RecipeBERT Modell (für semantische Zutat-Kombination)
bert_model_name = "alexdseo/RecipeBERT"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = AutoModel.from_pretrained(bert_model_name)
bert_model.eval() # Setze das Modell in den Evaluationsmodus
# Lade T5 Rezeptgenerierungsmodell
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
# Token Mapping für die T5 Modell-Ausgabe
special_tokens = t5_tokenizer.all_special_tokens
tokens_map = {
"<sep>": "--",
"<section>": "\n"
}
def get_embedding(text):
"""Berechnet das Embedding für einen Text mit Mean Pooling über alle Tokens"""
inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
with torch.no_grad():
outputs = bert_model(**inputs)
# Mean Pooling - Mittelwert aller Token-Embeddings
attention_mask = inputs['attention_mask']
token_embeddings = outputs.last_hidden_state
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
return (sum_embeddings / sum_mask).squeeze(0)
def average_embedding(embedding_list):
"""Berechnet den Durchschnitt einer Liste von Embeddings"""
tensors = torch.stack([emb for _, emb in embedding_list])
return tensors.mean(dim=0)
def get_cosine_similarity(vec1, vec2):
"""Berechnet die Cosinus-Ähnlichkeit zwischen zwei Vektoren"""
if torch.is_tensor(vec1):
vec1 = vec1.detach().numpy()
if torch.is_tensor(vec2):
vec2 = vec2.detach().numpy()
# Stelle sicher, dass die Vektoren die richtige Form haben (flachen sie bei Bedarf ab)
vec1 = vec1.flatten()
vec2 = vec2.flatten()
dot_product = np.dot(vec1, vec2)
norm_a = np.linalg.norm(vec1)
norm_b = np.linalg.norm(vec2)
# Division durch Null vermeiden
if norm_a == 0 or norm_b == 0:
return 0
return dot_product / (norm_a * norm_b)
def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
"""Berechnet einen kombinierten Score unter Berücksichtigung der Ähnlichkeit zum Durchschnitt und zu einzelnen Zutaten"""
results = []
for name, emb in embedding_list:
# Ähnlichkeit zum Durchschnittsvektor
avg_similarity = get_cosine_similarity(query_vector, emb)
# Durchschnittliche Ähnlichkeit zu einzelnen Zutaten
individual_similarities = [get_cosine_similarity(good_emb, emb)
for _, good_emb in all_good_embeddings]
avg_individual_similarity = sum(individual_similarities) / len(individual_similarities)
# Kombinierter Score (gewichteter Durchschnitt)
combined_score = avg_weight * avg_similarity + (1 - avg_weight) * avg_individual_similarity
results.append((name, emb, combined_score))
# Sortiere nach kombiniertem Score (absteigend)
results.sort(key=lambda x: x[2], reverse=True)
return results
def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
"""
Findet die besten Zutaten basierend auf RecipeBERT Embeddings.
"""
# Stelle sicher, dass keine Duplikate in den Listen sind
required_ingredients = list(set(required_ingredients))
available_ingredients = list(set([i for i in available_ingredients if i not in required_ingredients]))
# Sonderfall: Wenn keine benötigten Zutaten vorhanden sind, wähle zufällig eine aus den verfügbaren Zutaten
if not required_ingredients and available_ingredients:
random_ingredient = random.choice(available_ingredients)
required_ingredients = [random_ingredient]
available_ingredients = [i for i in available_ingredients if i != random_ingredient]
# print(f"Keine benötigten Zutaten angegeben. Zufällig ausgewählt: {random_ingredient}")
# Wenn immer noch keine Zutaten vorhanden oder bereits maximale Kapazität erreicht ist
if not required_ingredients or len(required_ingredients) >= max_ingredients:
return required_ingredients[:max_ingredients]
# Wenn keine zusätzlichen Zutaten verfügbar sind
if not available_ingredients:
return required_ingredients
# Berechne Embeddings für alle Zutaten
embed_required = [(e, get_embedding(e)) for e in required_ingredients]
embed_available = [(e, get_embedding(e)) for e in available_ingredients]
# Anzahl der hinzuzufügenden Zutaten
num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients))
# Kopiere benötigte Zutaten in die endgültige Liste
final_ingredients = embed_required.copy()
# Füge die besten Zutaten hinzu
for _ in range(num_to_add):
# Berechne den Durchschnittsvektor der aktuellen Kombination
avg = average_embedding(final_ingredients)
# Berechne kombinierte Scores für alle Kandidaten
candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight)
# Wenn keine Kandidaten mehr übrig sind, breche ab
if not candidates:
break
# Wähle die beste Zutat
best_name, best_embedding, _ = candidates[0]
# Füge die beste Zutat zur endgültigen Liste hinzu
final_ingredients.append((best_name, best_embedding))
# Entferne die Zutat aus den verfügbaren Zutaten
embed_available = [item for item in embed_available if item[0] != best_name]
# Extrahiere nur die Zutatennamen
return [name for name, _ in final_ingredients]
def skip_special_tokens(text, special_tokens):
"""Entfernt spezielle Tokens aus dem Text"""
for token in special_tokens:
text = text.replace(token, "")
return text
def target_postprocessing(texts, special_tokens):
"""Post-processed generierten Text"""
if not isinstance(texts, list):
texts = [texts]
new_texts = []
for text in texts:
text = skip_special_tokens(text, special_tokens)
for k, v in tokens_map.items():
text = text.replace(k, v)
new_texts.append(text)
return new_texts
def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
"""
Validiert, ob das Rezept ungefähr die erwarteten Zutaten enthält.
"""
recipe_count = len([ing for ing in recipe_ingredients if ing and ing.strip()])
expected_count = len(expected_ingredients)
return abs(recipe_count - expected_count) == tolerance
def generate_recipe_with_t5(ingredients_list, max_retries=5):
"""Generiert ein Rezept mit dem T5 Rezeptgenerierungsmodell mit Validierung."""
original_ingredients = ingredients_list.copy()
for attempt in range(max_retries):
try:
# Für Wiederholungsversuche nach dem ersten Versuch, mische die Zutaten
if attempt > 0:
current_ingredients = original_ingredients.copy()
random.shuffle(current_ingredients)
else:
current_ingredients = ingredients_list
# Formatiere Zutaten als kommaseparierten String
ingredients_string = ", ".join(current_ingredients)
prefix = "items: "
# Generationseinstellungen
generation_kwargs = {
"max_length": 512,
"min_length": 64,
"do_sample": True,
"top_k": 60,
"top_p": 0.95
}
# print(f"Versuch {attempt + 1}: {prefix + ingredients_string}")
# Tokenisiere Eingabe
inputs = t5_tokenizer(
prefix + ingredients_string,
max_length=256,
padding="max_length",
truncation=True,
return_tensors="jax"
)
# Generiere Text
output_ids = t5_model.generate(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
**generation_kwargs
)
# Dekodieren und Nachbearbeiten
generated = output_ids.sequences
generated_text = target_postprocessing(
t5_tokenizer.batch_decode(generated, skip_special_tokens=False),
special_tokens
)[0]
# Abschnitte parsen
recipe = {}
sections = generated_text.split("\n")
for section in sections:
section = section.strip()
if section.startswith("title:"):
recipe["title"] = section.replace("title:", "").strip().capitalize()
elif section.startswith("ingredients:"):
ingredients_text = section.replace("ingredients:", "").strip()
recipe["ingredients"] = [item.strip().capitalize() for item in ingredients_text.split("--") if item.strip()]
elif section.startswith("directions:"):
directions_text = section.replace("directions:", "").strip()
recipe["directions"] = [step.strip().capitalize() for step in directions_text.split("--") if step.strip()]
# Wenn der Titel fehlt, erstelle einen
if "title" not in recipe:
recipe["title"] = f"Rezept mit {', '.join(current_ingredients[:3])}"
# Stelle sicher, dass alle Abschnitte existieren
if "ingredients" not in recipe:
recipe["ingredients"] = current_ingredients
if "directions" not in recipe:
recipe["directions"] = ["Keine Anweisungen generiert"]
# Validiere das Rezept
if validate_recipe_ingredients(recipe["ingredients"], original_ingredients):
# print(f"Erfolg bei Versuch {attempt + 1}: Rezept hat die richtige Anzahl von Zutaten")
return recipe
else:
# print(f"Versuch {attempt + 1} fehlgeschlagen: Erwartet {len(original_ingredients)} Zutaten, erhalten {len(recipe['ingredients'])}")
if attempt == max_retries - 1:
# print("Maximale Wiederholungsversuche erreicht, letztes generiertes Rezept wird zurückgegeben")
return recipe
except Exception as e:
# print(f"Fehler bei der Rezeptgenerierung Versuch {attempt + 1}: {str(e)}")
if attempt == max_retries - 1:
return {
"title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
"ingredients": original_ingredients,
"directions": ["Fehler beim Generieren der Rezeptanweisungen"]
}
# Fallback (sollte nicht erreicht werden)
return {
"title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
"ingredients": original_ingredients,
"directions": ["Fehler beim Generieren der Rezeptanweisungen"]
}
# Diese Funktion wird von der Gradio-UI und der neuen FastAPI-Route aufgerufen.
def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries):
"""
Kernlogik zur Verarbeitung einer Rezeptgenerierungsanfrage.
Ausgelagert, um von verschiedenen Endpunkten aufgerufen zu werden.
"""
if not required_ingredients and not available_ingredients:
return {"error": "Keine Zutaten angegeben"}
try:
# Optimale Zutaten finden
optimized_ingredients = find_best_ingredients(
required_ingredients,
available_ingredients,
max_ingredients
)
# Rezept mit optimierten Zutaten generieren
recipe = generate_recipe_with_t5(optimized_ingredients, max_retries)
# Ergebnis formatieren
result = {
'title': recipe['title'],
'ingredients': recipe['ingredients'],
'directions': recipe['directions'],
'used_ingredients': optimized_ingredients
}
return result
except Exception as e:
return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"}
def flutter_api_generate_recipe(ingredients_data):
"""
Flutter-freundliche API-Funktion für den Gradio-API-Test-Tab.
Verarbeitet JSON-String-Eingabe und gibt JSON-String-Ausgabe zurück.
"""
try:
data = json.loads(ingredients_data) # Muss ein JSON-String sein
required_ingredients = data.get('required_ingredients', [])
available_ingredients = data.get('available_ingredients', [])
max_ingredients = data.get('max_ingredients', 7)
max_retries = data.get('max_retries', 5)
# Rufe die Kernlogik auf
result_dict = process_recipe_request_logic(
required_ingredients, available_ingredients, max_ingredients, max_retries
)
return json.dumps(result_dict) # Gibt einen JSON-String zurück
except Exception as e:
return json.dumps({"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"})
def gradio_ui_generate_recipe(required_ingredients_text, available_ingredients_text, max_ingredients_val, max_retries_val):
"""Gradio UI Funktion für die Web-Oberfläche"""
try:
required_ingredients = [ing.strip() for ing in required_ingredients_text.split(',') if ing.strip()]
available_ingredients = [ing.strip() for ing in available_ingredients_text.split(',') if ing.strip()]
# Rufe die Kernlogik auf
result = process_recipe_request_logic(
required_ingredients, available_ingredients, max_ingredients_val, max_retries_val
)
if 'error' in result:
return result['error'], "", "", ""
ingredients_list = '\n'.join([f"• {ing}" for ing in result['ingredients']])
directions_list = '\n'.join([f"{i+1}. {dir}" for i, dir in enumerate(result['directions'])])
used_ingredients = ', '.join(result['used_ingredients'])
return (
result['title'],
ingredients_list,
directions_list,
used_ingredients
)
except Exception as e:
return f"Fehler: {str(e)}", "", "", ""
# Erstelle die Gradio Oberfläche
with gr.Blocks(title="AI Rezept Generator") as demo:
gr.Markdown("# 🍳 AI Rezept Generator")
gr.Markdown("Generiere Rezepte mit KI und intelligenter Zutat-Kombination!")
with gr.Tab("Web-Oberfläche"):
with gr.Row():
with gr.Column():
required_ing = gr.Textbox(
label="Benötigte Zutaten (kommasepariert)",
placeholder="Hähnchen, Reis, Zwiebel",
lines=2
)
available_ing = gr.Textbox(
label="Verfügbare Zutaten (kommasepariert, optional)",
placeholder="Knoblauch, Tomate, Pfeffer, Kräuter",
lines=2
)
max_ing = gr.Slider(3, 10, value=7, step=1, label="Maximale Zutaten")
max_retries = gr.Slider(1, 10, value=5, step=1, label="Max. Wiederholungsversuche")
generate_btn = gr.Button("Rezept generieren", variant="primary")
with gr.Column():
title_output = gr.Textbox(label="Rezepttitel", interactive=False)
ingredients_output = gr.Textbox(label="Zutaten", lines=8, interactive=False)
directions_output = gr.Textbox(label="Anweisungen", lines=10, interactive=False)
used_ingredients_output = gr.Textbox(label="Verwendete Zutaten", interactive=False)
generate_btn.click(
fn=gradio_ui_generate_recipe,
inputs=[required_ing, available_ing, max_ing, max_retries],
outputs=[title_output, ingredients_output, directions_output, used_ingredients_output]
)
with gr.Tab("API-Test"):
gr.Markdown("### Teste die Gradio API (für interne Tests)")
gr.Markdown("Dieser Tab verwendet die interne Gradio API für Testzwecke.")
api_input = gr.Textbox(
label="JSON-Eingabe (Flutter API-Format)",
placeholder='{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic"], "max_ingredients": 6}',
lines=4
)
api_output = gr.Textbox(label="JSON-Ausgabe", lines=15, interactive=False)
api_test_btn = gr.Button("API testen", variant="secondary")
# Hier wird die Funktion weiterhin für den Gradio-eigenen API-Test-Tab verwendet.
api_test_btn.click(
fn=flutter_api_generate_recipe,
inputs=[api_input],
outputs=[api_output],
api_name="generate_recipe_for_flutter" # Gradio-interner API-Name
)
gr.Examples(
examples=[
['{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic", "tomato"], "max_ingredients": 6}'],
['{"ingredients": ["pasta"], "available_ingredients": ["cheese", "mushrooms", "cream"], "max_ingredients": 5}']
],
inputs=[api_input]
)
# --- FastAPI-Integration ---
app = FastAPI()
class RecipeRequest(BaseModel):
required_ingredients: list[str] = []
available_ingredients: list[str] = []
max_ingredients: int = 7
max_retries: int = 5
@app.post("/api/generate_recipe_rest")
async def generate_recipe_rest_api(request_data: RecipeRequest):
"""
Standard-REST-API-Endpunkt für die Flutter-App.
Nimmt direkt JSON-Daten an und gibt direkt JSON zurück.
"""
required_ingredients = request_data.required_ingredients
available_ingredients = request_data.available_ingredients
max_ingredients = request_data.max_ingredients
max_retries = request_data.max_retries
# Abwärtskompatibilität, falls 'ingredients' statt 'required_ingredients' gesendet wird
# Dies ist in der FastAPI-Pydantic-Modelldefinition nicht direkt abbildbar,
# aber du könntest es manuell hinzufügen, falls nötig, wenn das Pydantic-Modell flexibler wäre.
# Für den Einfachheit halber gehen wir davon aus, dass Flutter die korrekten Felder sendet.
result_dict = process_recipe_request_logic(
required_ingredients, available_ingredients, max_ingredients, max_retries
)
return JSONResponse(content=result_dict)
# Gradio-App als Sub-App in die FastAPI-App mounten
# Dies ist der Standardweg, um Gradio in eine FastAPI-Anwendung einzubetten.
# Der Gradio-Teil wird dann unter dem Wurzelpfad '/'.
app = gr.mount_gradio_app(app, demo, path="/") # Gradio unter dem Wurzelpfad mounten
# Wenn du deine App lokal ausführst, kannst du FastAPI mit Uvicorn starten:
# if __name__ == "__main__":
# import uvicorn
# uvicorn.run(app, host="0.0.0.0", port=8000)
# Für Hugging Face Spaces ist der if __name__ == "__main__": Block nicht nötig,
# da Spaces Uvicorn automatisch startet und die "app"-Variable sucht.