# MixologyAI — app.py (Hugging Face Space by Oren1440, commit 0f12821)
import re
import warnings
from typing import Iterable, List, Dict, Any
import requests
import torch
import gradio as gr
from PIL import Image
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# ---------------- Global state ----------------
# All five model/data globals are populated in the __main__ block at startup
# and read by predict_cocktails() / generate_and_select_ai_cocktail().
processed_data = None  # list of {"title", "tokens", "index"} dicts from process_all()
raw_dataset = None  # raw "erwanlc/cocktails_recipe" train split (for full recipes)
embedding_model = None  # SentenceTransformer used for similarity ranking
llm_model = None  # optional local causal LM; None if loading failed
llm_tokenizer = None  # tokenizer paired with llm_model; None if loading failed
STOP = {",", "(", ")", "'"}  # punctuation that ends a token run in join_letters_between_punct()
def remove_numbers_and_dots(text):
    """Strip every digit and dot from *text*; non-strings pass through unchanged."""
    if not isinstance(text, str):
        return text
    return re.sub(r"[\d.]", "", text)
def iter_chars(tok) -> Iterable[str]:
    """Return the characters of *tok* as a list.

    None and the empty list yield nothing; a non-empty list contributes
    only its first element; any other value is stringified first.
    """
    if tok is None:
        return []
    if not isinstance(tok, list):
        return list(str(tok))
    return list(str(tok[0])) if tok else []
def join_letters_between_punct(cleaned_tokens: List[Iterable[str]]) -> List[List[str]]:
    """Group alphabetic character runs found between '[' and ']' markers.

    Scans the characters of every token (via iter_chars). '[' opens a new
    result sub-list and ']' closes it; characters outside brackets are
    ignored.  Inside a bracketed region, a run starts at a letter and keeps
    accumulating any character until a STOP punctuation mark or a bracket
    is hit, at which point the run is whitespace-normalised and committed.
    Returns the list of closed sub-lists.
    """
    results: List[List[str]] = []
    curr_list: List[str] | None = None  # sub-list being filled; None => outside brackets
    building = False  # True while a letter-run is being accumulated in buf
    buf: List[str] = []

    def commit():
        # Flush the current run into curr_list (collapsing internal whitespace);
        # no-op when no run is in progress.
        nonlocal building, buf, curr_list
        if not building:
            return
        token = "".join(buf)
        token = " ".join(token.split())
        if token and curr_list is not None:
            curr_list.append(token)
        building = False
        buf = []

    for tok in cleaned_tokens:
        for ch in iter_chars(tok):
            if ch == "[":
                commit()
                curr_list = []
                continue
            if ch == "]":
                commit()
                if curr_list is not None:
                    results.append(curr_list)
                curr_list = None
                continue
            if curr_list is None:
                # Outside any bracketed region: skip the character entirely.
                continue
            if ch in STOP:
                commit()
                continue
            if not building:
                # Only a letter may start a new run; other chars are dropped.
                if ch.isalpha():
                    building = True
                    buf = [ch]
                else:
                    continue
            else:
                # Once a run has started, keep every character (incl. spaces).
                buf.append(ch)
    # Flush a trailing run; it is kept only if a list is still open.
    commit()
    return results
def process_all():
    """Download the cocktails dataset and pre-tokenise every recipe.

    Returns a list of dicts {"title", "tokens", "index"} where "tokens" is
    the lower-cased, purely-alphabetic ingredient words of row "index" in
    the raw "erwanlc/cocktails_recipe" train split.
    """
    # Hebrew: "Starting dataset pre-processing..."
    print("ืžืชื—ื™ืœ ืขื™ื‘ื•ื“ ืžืงื“ื™ื ืฉืœ ื”ื“ืื˜ื” ืกื˜...")
    ds = load_dataset("erwanlc/cocktails_recipe")["train"]
    output = []
    for i in range(len(ds)):
        title, ingredients = ds[i]["title"], ds[i]["ingredients"]
        # Drop digits/dots from every cell, then regroup letter runs
        # (the dataset stores ingredients as nested bracketed lists).
        cleaned = [[remove_numbers_and_dots(part) for part in item] for item in ingredients]
        grouped = join_letters_between_punct(cleaned)
        flat = [word for sub in grouped for word in sub]
        # Keep only purely alphabetic words, lower-cased, for embedding.
        final_tokens = [w.lower() for word in flat for w in word.split() if w.isalpha()]
        output.append({"title": title, "tokens": final_tokens, "index": i})
    # Hebrew: "Pre-processing finished."
    print("ื”ืขื™ื‘ื•ื“ ื”ืžืงื“ื™ื ื”ืกืชื™ื™ื.")
    return output
def inner_lists_as_lines(raw_ingredients) -> list:
    """Flatten a nested ingredients structure into display lines.

    Each inner list becomes one space-joined line (None/blank cells are
    skipped, nested tuples/lists are flattened one level); scalar items
    become their own stripped line.  A non-list input collapses to a
    single stripped line.  Empty lines are dropped.
    """
    if not isinstance(raw_ingredients, list):
        return [str(raw_ingredients).strip()]
    lines = []
    for entry in raw_ingredients:
        if not isinstance(entry, list):
            text = str(entry).strip()
        else:
            pieces = []
            for cell in entry:
                if cell is None:
                    continue
                if isinstance(cell, (list, tuple)):
                    for sub in cell:
                        if sub is None:
                            continue
                        s = str(sub).strip()
                        if s:
                            pieces.append(s)
                else:
                    s = str(cell).strip()
                    if s:
                        pieces.append(s)
            text = " ".join(pieces).strip()
        if text:
            lines.append(text)
    return lines
def parse_llm_output(text: str) -> List[Dict[str, Any]]:
    """Extract cocktail dicts from markdown-style LLM output.

    Recognises '**Title:**', '**Ingredients:**' and '**Recipe:**'
    sections.  Returns [{"title", "ingredients", "recipe"}] when at least
    an ingredients or recipe section is present, otherwise an empty list.
    The title falls back to "AI Generated Screwdriver" when missing.
    """
    cocktails: List[Dict[str, Any]] = []
    block = text
    if block.strip():
        try:
            m_title = re.search(r'\*\*Title:\*\* (.*?)\n', block, re.DOTALL)
            title = m_title.group(1).strip() if m_title else "AI Generated Screwdriver"
            # Prefer the ingredients span bounded by the Recipe header; fall
            # back to "everything after Ingredients" when no recipe follows.
            m_ing = re.search(r'\*\*Ingredients:\*\*(.*?)\n\s*\*\*Recipe:\*\*', block, re.DOTALL)
            if not m_ing:
                m_ing = re.search(r'\*\*Ingredients:\*\*(.*)', block, re.DOTALL)
            m_recipe = re.search(r'\*\*Recipe:\*\*\s*(.*)', block, re.DOTALL)
            if m_ing or m_recipe:
                raw_ing = m_ing.group(1).strip() if m_ing else "Not specified."
                # One ingredient per non-blank line, bullet markers stripped.
                ingredients = [line.strip().lstrip('-* ').strip() for line in raw_ing.split('\n') if line.strip()]
                recipe = m_recipe.group(1).strip() if m_recipe else "Not specified."
                cocktails.append({"title": title, "ingredients": ingredients, "recipe": recipe})
        except Exception as e:
            print(f"Error parsing block: {e}\nBlock content:\n{block}")
    return cocktails
def generate_and_select_ai_cocktail(alcohol_query: str, fruit_query: str) -> Dict[str, Any] | None:
    """Generate one cocktail with the local LLM; return None on any failure.

    Prompts the chat model (Gemma-style turn markers) for a single recipe
    using the two query terms, strips the echoed prompt from the decoded
    output, and returns the first cocktail parse_llm_output() finds.
    """
    global llm_model, llm_tokenizer
    if llm_model is None or llm_tokenizer is None:
        print("LLM model not loaded, skipping local AI generation.")
        return None
    print("Generating AI cocktail using the local LLM...")
    # NOTE(review): the prompt's internal lines are intentionally flush-left;
    # they are part of the runtime string sent to the model.
    prompt = f"""
<bos><start_of_turn>user
You are an expert mixologist. Invent one unique cocktail. It must use '{alcohol_query}' as the base spirit and '{fruit_query}' as a key ingredient. Provide the answer in the following format exactly:
**Title:** [name]
**Ingredients:**
- [Amount] [Unit] {alcohol_query}
- [Amount] [Unit] {fruit_query}
- [other ingredients]
**Recipe:** [steps]
<end_of_turn>
<start_of_turn>model
"""
    try:
        inputs = llm_tokenizer(prompt, return_tensors="pt").to(llm_model.device)
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=250,
            temperature=0.8,
            do_sample=True,
            top_p=0.95,
            top_k=50
        )
        # Keep special tokens so the turn marker searched below is present.
        generated_text = llm_tokenizer.decode(outputs[0], skip_special_tokens=False)
        # generate() echoes the prompt; keep only the model's turn.
        model_response_start = generated_text.find("<start_of_turn>model")
        if model_response_start != -1:
            generated_text = generated_text[model_response_start + len("<start_of_turn>model"):].strip()
    except Exception as e:
        print(f"Error during local LLM generation: {e}")
        return None
    if not generated_text:
        print("Local model generated an empty response.")
        return None
    print("\n--- Raw AI Output ---\n", generated_text, "\n---------------------\n")
    ai_cocktails = parse_llm_output(generated_text)
    if not ai_cocktails:
        print("Failed to parse the locally generated cocktail.")
        return None
    best_cocktail = ai_cocktails[0]
    print(f"Selected AI cocktail (local): {best_cocktail.get('title', 'N/A')}")
    return best_cocktail
# Lazily-built Stable Diffusion pipeline, shared across requests (it is far
# too expensive to rebuild per call, which is what the code previously did).
_SD_PIPE = None


def _get_sd_pipeline():
    """Build the Stable Diffusion pipeline on first use and cache it.

    The pipeline is stateless between calls, so one shared instance is safe
    for Gradio's sequential request handling.
    """
    global _SD_PIPE
    if _SD_PIPE is None:
        use_cuda = torch.cuda.is_available()
        dtype = torch.float16 if use_cuda else torch.float32
        pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=dtype,
            safety_checker=None
        )
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        if use_cuda:
            pipe = pipe.to("cuda")
        _SD_PIPE = pipe
    return _SD_PIPE


def predict_cocktails(alcohol_query, fruit_query, generate_image):
    """Two-stage similarity search plus optional AI recipe and image.

    Stage 1 ranks every dataset cocktail against *alcohol_query*; stage 2
    re-ranks the top 300 against *fruit_query*.  The top 3 are rendered as
    markdown, an LLM-generated cocktail is appended when available, and a
    Stable Diffusion image of the best match is generated when
    *generate_image* is truthy.  Returns (markdown_text, image_or_None).
    """
    global processed_data, raw_dataset, embedding_model
    alcohol_vec = embedding_model.encode(alcohol_query, convert_to_tensor=True)
    # Encode each cocktail's ingredient text exactly once and keep the
    # vector; previously stage 2 re-encoded the same top-300 texts again,
    # doubling the embedding work per request for identical results.
    stage1_results = []
    for item in processed_data:
        ingredient_text = " ".join(item['tokens'])
        if not ingredient_text:
            vec = None
            score = 0.0
        else:
            vec = embedding_model.encode(ingredient_text, convert_to_tensor=True)
            score = util.cos_sim(alcohol_vec, vec).item()
        stage1_results.append({'score': score, 'data': item, 'vec': vec})
    stage1_results.sort(key=lambda x: x['score'], reverse=True)
    top_300 = stage1_results[:300]

    fruit_vec = embedding_model.encode(fruit_query, convert_to_tensor=True)
    stage2_results = []
    for entry in top_300:
        vec = entry['vec']
        score = 0.0 if vec is None else util.cos_sim(fruit_vec, vec).item()
        stage2_results.append({'score': score, 'data': entry['data']})
    stage2_results.sort(key=lambda x: x['score'], reverse=True)

    recommendations_text = ""
    recommended_cocktails = []
    for i, result in enumerate(stage2_results[:3]):
        final_score = result['score']
        cocktail_info = result['data']
        title = cocktail_info['title']
        original_index = cocktail_info['index']
        raw = raw_dataset[original_index]
        rec = {"title": title, "glass": raw.get("glass"), "garnish": raw.get("garnish"), "index": original_index}
        recommended_cocktails.append(rec)
        icon = "๐Ÿ†" if i == 0 else ""
        recommendations_text += f"--- \n\n"
        recommendations_text += f"## {i+1}. {title} {icon}\n\n"
        recommendations_text += f"**Similarity Score:** {final_score:.2f}\n"
        recommendations_text += f"\n**Ingredients:**\n"
        for line in inner_lists_as_lines(raw.get("ingredients", [])):
            recommendations_text += f"- {line}\n"
        recommendations_text += f"\n**Recipe:**\n"
        recipe_steps = raw.get("recipe", "No recipe provided.")
        if isinstance(recipe_steps, str):
            recommendations_text += f"- {recipe_steps.strip()}\n"
        elif isinstance(recipe_steps, list):
            recipe_string = ' '.join(str(step).strip() for step in recipe_steps)
            recommendations_text += f"- {recipe_string}\n"
        else:
            recommendations_text += f"- {recipe_steps}\n"

    # Append the locally-generated AI cocktail section when available.
    ai_cocktail = generate_and_select_ai_cocktail(alcohol_query, fruit_query)
    if ai_cocktail:
        recommendations_text += f"\n---\n\n"
        recommendations_text += f"## โœจ 4. AI-Generated Creation โœจ\n\n"
        recommendations_text += f"**Title:** {ai_cocktail['title']}\n\n"
        recommendations_text += f"**Ingredients:**\n"
        for ingredient in ai_cocktail.get('ingredients', []):
            recommendations_text += f"- {ingredient}\n"
        recommendations_text += f"\n**Recipe:**\n"
        recommendations_text += f"- {ai_cocktail.get('recipe', 'No recipe generated.')}\n"
        recommendations_text += f"\n---\n"

    image = None
    if recommended_cocktails and generate_image:
        print("Generating cocktail image...")
        first_cocktail = recommended_cocktails[0]
        title_txt = str(first_cocktail.get("title") or "cocktail")
        glass_txt = str(first_cocktail.get("glass") or "appropriate")
        garnish_txt = str(first_cocktail.get("garnish") or "a suitable garnish")
        prompt = (
            f"A photorealistic image of the cocktail {title_txt}. "
            f"It is served in a {glass_txt} glass and is garnished with {garnish_txt}. "
            f"High quality, food photography, professional lighting, shallow depth of field."
        )
        pipe = _get_sd_pipeline()
        image: Image.Image = pipe(prompt).images[0]
        print("Image generation complete.")
    return recommendations_text, image
if __name__ == "__main__":
    # NOTE(review): the original indentation was lost in transit; this
    # reconstruction keeps the whole startup + UI inside catch_warnings()
    # (warnings silenced for the app's lifetime) — confirm against the
    # original file if warning visibility matters.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # --- Startup: heavy one-time loading of data and models ---
        processed_data = process_all()
        raw_dataset = load_dataset("erwanlc/cocktails_recipe")["train"]
        embedding_model_name = 'paraphrase-MiniLM-L6-v2'
        # Hebrew: "Loading the embedding model: ..."
        print(f"\nื˜ื•ืขืŸ ืืช ืžื•ื“ืœ ื”ื•ื•ืงื˜ื•ืจื™ื: {embedding_model_name}...")
        embedding_model = SentenceTransformer(embedding_model_name)
        # Hebrew: "Embedding model ready."
        print("ืžื•ื“ืœ ื”ื•ื•ืงื˜ื•ืจื™ื ืžื•ื›ืŸ.")
        llm_model_name = "google/gemma-2b-it"
        # Hebrew: "Loading the language model: ..."
        print(f"\nื˜ื•ืขืŸ ืืช ืžื•ื“ืœ ื”ืฉืคื”: {llm_model_name}...")
        try:
            # Quantization settings so the model needs less memory.
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16
            )
            llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
            llm_model = AutoModelForCausalLM.from_pretrained(
                llm_model_name,
                quantization_config=quantization_config,
                device_map="auto"
            )
            # Hebrew: "Language model ready."
            print("ืžื•ื“ืœ ื”ืฉืคื” ืžื•ื›ืŸ.")
        except Exception as e:
            # The LLM is optional: the app degrades to dataset-only results.
            print(f"Error loading LLM: {e}")
            llm_model = None
            llm_tokenizer = None
            # Hebrew: "The app will keep running without new cocktails from the local model."
            print("ื”ืืคืœื™ืงืฆื™ื” ืชืžืฉื™ืš ืœืจื•ืฅ ืœืœื ืงื•ืงื˜ื™ื™ืœื™ื ื—ื“ืฉื™ื ืžื”ืžื•ื“ืœ ื”ืžืงื•ืžื™.")
        # --- Gradio UI ---
        alcohol_options = ["Vodka", "Whiskey", "Gin", "Tequila", "Wine"]
        ingredient_options = ["Orange Juice", "Apple", "Cranberry", "Lemon", "Thyme"]
        with gr.Blocks() as demo:
            gr.Markdown(
                """
# ๐Ÿธ Cocktail Creation Assistant ๐Ÿธ
Enter a base alcohol and a main ingredient to get three cocktail recipes from the dataset,
plus one unique recipe created by the local AI model.
Use the buttons for quick suggestions!
"""
            )
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### 1. Choose Your Ingredients")
                    alcohol_input = gr.Textbox(label="Base Alcohol", placeholder="e.g., Gin, Rum, Vodka...")
                    with gr.Row():
                        for alcohol in alcohol_options:
                            btn = gr.Button(alcohol)
                            # Default-arg binding freezes the loop value per button
                            # (a bare lambda would capture the last iteration only).
                            btn.click(lambda val=alcohol: val, outputs=alcohol_input)
                    ingredient_input = gr.Textbox(label="Main Ingredient", placeholder="e.g., Orange, Lime, Mint...")
                    with gr.Row():
                        for ingredient in ingredient_options:
                            btn = gr.Button(ingredient)
                            btn.click(lambda val=ingredient: val, outputs=ingredient_input)
                    generate_image_checkbox = gr.Checkbox(label="Generate image for top result?", value=False)
                    submit_btn = gr.Button("๐Ÿน Get My Cocktails!", variant="primary")
                with gr.Column(scale=2):
                    gr.Markdown("### 2. Your Recommendations")
                    output_markdown = gr.Markdown(label="Cocktail Recommendations")
                    output_image = gr.Image(label="Generated Cocktail Photo")
            submit_btn.click(
                fn=predict_cocktails,
                inputs=[alcohol_input, ingredient_input, generate_image_checkbox],
                outputs=[output_markdown, output_image]
            )
        demo.launch(debug=True)