Spaces:
Sleeping
Sleeping
| import re | |
| import warnings | |
| from typing import Iterable, List, Dict, Any | |
| import requests | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer, util | |
| from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig | |
# ---------------- Global state ----------------
# All of these are populated once at startup in the __main__ block below.
processed_data = None    # list of {"title", "tokens", "index"} dicts built by process_all()
raw_dataset = None       # raw HF "erwanlc/cocktails_recipe" train split, indexed by row
embedding_model = None   # SentenceTransformer used for ingredient similarity search
llm_model = None         # local causal LM for recipe generation (stays None if loading fails)
llm_tokenizer = None     # tokenizer paired with llm_model
# Characters that terminate the word currently being scanned inside a
# bracketed ingredient group (see join_letters_between_punct).
STOP = {",", "(", ")", "'"}
def remove_numbers_and_dots(text):
    """Return *text* with every digit and dot character removed.

    Non-string inputs pass through unchanged.
    """
    if not isinstance(text, str):
        return text
    return re.sub(r"[\d.]", "", text)
def iter_chars(tok) -> Iterable[str]:
    """Return the characters of *tok* as a list of 1-char strings.

    ``None`` and empty lists yield no characters; a non-empty list
    contributes only its FIRST element (stringified). Any other value
    is stringified and split into characters.
    """
    if tok is None:
        return []
    if isinstance(tok, list):
        if not tok:
            return []
        # Deliberately only the first cell of a list token is scanned.
        return list(str(tok[0]))
    return list(str(tok))
def join_letters_between_punct(cleaned_tokens: List[Iterable[str]]) -> List[List[str]]:
    """Group alphabetic word runs found between '[' and ']' markers.

    Scans the characters of every token (via iter_chars). A '[' opens a
    new group and ']' closes it, appending the group to the results.
    Inside a group, characters in STOP (comma, parens, quote) end the
    word currently being built; characters outside any group are ignored.
    A word only starts on an alphabetic character.
    """
    results: List[List[str]] = []
    curr_list: List[str] | None = None  # the open group; None while outside brackets
    building = False                    # True while a word is accumulating in buf
    buf: List[str] = []

    def commit():
        # Flush buf into the open group as one whitespace-normalised word.
        # No-op when nothing is being built; drops the word if no group is open.
        nonlocal building, buf, curr_list
        if not building:
            return
        token = "".join(buf)
        token = " ".join(token.split())
        if token and curr_list is not None:
            curr_list.append(token)
        building = False
        buf = []

    for tok in cleaned_tokens:
        for ch in iter_chars(tok):
            if ch == "[":
                commit()
                curr_list = []
                continue
            if ch == "]":
                commit()
                if curr_list is not None:
                    results.append(curr_list)
                curr_list = None
                continue
            if curr_list is None:
                # Outside any bracketed group: ignore the character.
                continue
            if ch in STOP:
                commit()
                continue
            if not building:
                if ch.isalpha():
                    # A letter begins a new word.
                    building = True
                    buf = [ch]
                else:
                    continue
            else:
                buf.append(ch)
    # Flush any word still pending after the input ends (only lands in
    # results if its group was closed, matching the ']' handling above).
    commit()
    return results
def process_all():
    """Load the cocktails dataset and pre-tokenise every recipe.

    Returns a list of dicts, one per dataset row:
    ``{"title": str, "tokens": [str, ...], "index": int}`` where tokens
    are the lower-cased, purely alphabetic ingredient words.
    """
    print("ืืชืืื ืขืืืื ืืงืืื ืฉื ืืืืื ืกื...")
    ds = load_dataset("erwanlc/cocktails_recipe")["train"]
    output = []
    # Iterate rows directly with enumerate instead of range(len(ds)) indexing.
    for i, row in enumerate(ds):
        title, ingredients = row["title"], row["ingredients"]
        # Strip digits/dots from every cell, then group words between brackets.
        cleaned = [[remove_numbers_and_dots(part) for part in item] for item in ingredients]
        grouped = join_letters_between_punct(cleaned)
        flat = [word for sub in grouped for word in sub]
        final_tokens = [w.lower() for word in flat for w in word.split() if w.isalpha()]
        output.append({"title": title, "tokens": final_tokens, "index": i})
    print("ืืขืืืื ืืืงืืื ืืกืชืืื.")
    return output
def inner_lists_as_lines(raw_ingredients) -> list:
    """Flatten a raw ingredients structure into display lines.

    Each inner list becomes one space-joined line; scalar entries become
    their own line. ``None`` cells and empty fragments are dropped, and a
    non-list input yields a single stripped line.
    """
    if not isinstance(raw_ingredients, list):
        return [str(raw_ingredients).strip()]

    def _fragments(cell):
        # Expand one cell into its non-empty string fragments.
        if cell is None:
            return []
        if isinstance(cell, (list, tuple)):
            return [str(x).strip() for x in cell if x is not None and str(x).strip()]
        text = str(cell).strip()
        return [text] if text else []

    lines = []
    for entry in raw_ingredients:
        if isinstance(entry, list):
            pieces = []
            for cell in entry:
                pieces.extend(_fragments(cell))
            line = " ".join(pieces).strip()
        else:
            line = str(entry).strip()
        if line:
            lines.append(line)
    return lines
def parse_llm_output(text: str) -> List[Dict[str, Any]]:
    """Parse markdown-style LLM output into cocktail dicts.

    Looks for ``**Title:**``, ``**Ingredients:**`` and ``**Recipe:**``
    sections and returns a list of ``{"title", "ingredients", "recipe"}``
    dicts; the list is empty when nothing could be parsed.
    """
    parsed: List[Dict[str, Any]] = []
    for chunk in [text]:
        if not chunk.strip():
            continue
        try:
            m_title = re.search(r'\*\*Title:\*\* (.*?)\n', chunk, re.DOTALL)
            title = m_title.group(1).strip() if m_title else "AI Generated Screwdriver"
            # Prefer the bounded match (up to **Recipe:**); fall back to
            # everything after the Ingredients header.
            m_ing = (
                re.search(r'\*\*Ingredients:\*\*(.*?)\n\s*\*\*Recipe:\*\*', chunk, re.DOTALL)
                or re.search(r'\*\*Ingredients:\*\*(.*)', chunk, re.DOTALL)
            )
            m_recipe = re.search(r'\*\*Recipe:\*\*\s*(.*)', chunk, re.DOTALL)
            if not (m_ing or m_recipe):
                continue
            ing_text = m_ing.group(1).strip() if m_ing else "Not specified."
            ingredients = [
                line.strip().lstrip('-* ').strip()
                for line in ing_text.split('\n')
                if line.strip()
            ]
            recipe = m_recipe.group(1).strip() if m_recipe else "Not specified."
            parsed.append({"title": title, "ingredients": ingredients, "recipe": recipe})
        except Exception as e:
            print(f"Error parsing block: {e}\nBlock content:\n{chunk}")
    return parsed
def generate_and_select_ai_cocktail(alcohol_query: str, fruit_query: str) -> Dict[str, Any] | None:
    """Ask the locally loaded LLM to invent one cocktail and return it parsed.

    Uses the module-level ``llm_model``/``llm_tokenizer``. Returns the first
    cocktail dict parsed from the model output, or ``None`` when the model is
    unavailable, generation raises, the response is empty, or parsing fails.
    """
    global llm_model, llm_tokenizer
    if llm_model is None or llm_tokenizer is None:
        print("LLM model not loaded, skipping local AI generation.")
        return None
    print("Generating AI cocktail using the local LLM...")
    # Gemma-style chat template; the model's reply follows the final
    # "<start_of_turn>model" marker. The explicit format request makes the
    # output parseable by parse_llm_output.
    prompt = f"""
<bos><start_of_turn>user
You are an expert mixologist. Invent one unique cocktail. It must use '{alcohol_query}' as the base spirit and '{fruit_query}' as a key ingredient. Provide the answer in the following format exactly:
**Title:** [name]
**Ingredients:**
- [Amount] [Unit] {alcohol_query}
- [Amount] [Unit] {fruit_query}
- [other ingredients]
**Recipe:** [steps]
<end_of_turn>
<start_of_turn>model
"""
    try:
        inputs = llm_tokenizer(prompt, return_tensors="pt").to(llm_model.device)
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=250,
            temperature=0.8,
            do_sample=True,
            top_p=0.95,
            top_k=50
        )
        # skip_special_tokens=False keeps the turn markers so the model's
        # reply can be located below.
        generated_text = llm_tokenizer.decode(outputs[0], skip_special_tokens=False)
        model_response_start = generated_text.find("<start_of_turn>model")
        if model_response_start != -1:
            # Keep only the text after the model turn marker.
            generated_text = generated_text[model_response_start + len("<start_of_turn>model"):].strip()
    except Exception as e:
        print(f"Error during local LLM generation: {e}")
        return None
    if not generated_text:
        print("Local model generated an empty response.")
        return None
    print("\n--- Raw AI Output ---\n", generated_text, "\n---------------------\n")
    ai_cocktails = parse_llm_output(generated_text)
    if not ai_cocktails:
        print("Failed to parse the locally generated cocktail.")
        return None
    # Only one cocktail is requested per prompt, so the first parse is "best".
    best_cocktail = ai_cocktails[0]
    print(f"Selected AI cocktail (local): {best_cocktail.get('title', 'N/A')}")
    return best_cocktail
def _format_recommendations(top_results):
    """Render the top dataset matches as markdown.

    Returns ``(markdown_text, rec_dicts)`` where each rec dict carries the
    title/glass/garnish/index needed for optional image generation.
    """
    recommendations_text = ""
    recommended_cocktails = []
    for i, result in enumerate(top_results):
        final_score = result['score']
        cocktail_info = result['data']
        title = cocktail_info['title']
        original_index = cocktail_info['index']
        raw = raw_dataset[original_index]
        rec = {"title": title, "glass": raw.get("glass"), "garnish": raw.get("garnish"), "index": original_index}
        recommended_cocktails.append(rec)
        icon = "๐" if i == 0 else ""
        recommendations_text += f"--- \n\n"
        recommendations_text += f"## {i+1}. {title} {icon}\n\n"
        recommendations_text += f"**Similarity Score:** {final_score:.2f}\n"
        recommendations_text += f"\n**Ingredients:**\n"
        for line in inner_lists_as_lines(raw.get("ingredients", [])):
            recommendations_text += f"- {line}\n"
        recommendations_text += f"\n**Recipe:**\n"
        recipe_steps = raw.get("recipe", "No recipe provided.")
        if isinstance(recipe_steps, str):
            recommendations_text += f"- {recipe_steps.strip()}\n"
        elif isinstance(recipe_steps, list):
            recipe_string = ' '.join(str(step).strip() for step in recipe_steps)
            recommendations_text += f"- {recipe_string}\n"
        else:
            recommendations_text += f"- {recipe_steps}\n"
    return recommendations_text, recommended_cocktails


def _format_ai_section(alcohol_query, fruit_query):
    """Markdown block for the locally generated AI cocktail ('' on failure)."""
    ai_cocktail = generate_and_select_ai_cocktail(alcohol_query, fruit_query)
    if not ai_cocktail:
        return ""
    text = f"\n---\n\n"
    text += f"## โจ 4. AI-Generated Creation โจ\n\n"
    text += f"**Title:** {ai_cocktail['title']}\n\n"
    text += f"**Ingredients:**\n"
    for ingredient in ai_cocktail.get('ingredients', []):
        text += f"- {ingredient}\n"
    text += f"\n**Recipe:**\n"
    text += f"- {ai_cocktail.get('recipe', 'No recipe generated.')}\n"
    text += f"\n---\n"
    return text


def _generate_cocktail_image(first_cocktail):
    """Render a Stable Diffusion photo of *first_cocktail*.

    NOTE(review): the pipeline is re-loaded on every call, matching the
    original behavior; loading it once at startup would be faster but
    keeps the model resident in memory.
    """
    print("Generating cocktail image...")
    title_txt = str(first_cocktail.get("title") or "cocktail")
    glass_txt = str(first_cocktail.get("glass") or "appropriate")
    garnish_txt = str(first_cocktail.get("garnish") or "a suitable garnish")
    prompt = (
        f"A photorealistic image of the cocktail {title_txt}. "
        f"It is served in a {glass_txt} glass and is garnished with {garnish_txt}. "
        f"High quality, food photography, professional lighting, shallow depth of field."
    )
    use_cuda = torch.cuda.is_available()
    dtype = torch.float16 if use_cuda else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=dtype,
        safety_checker=None
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    if use_cuda:
        pipe = pipe.to("cuda")
    image = pipe(prompt).images[0]
    print("Image generation complete.")
    return image


def predict_cocktails(alcohol_query, fruit_query, generate_image):
    """Two-stage semantic search over the cocktail dataset, plus an AI extra.

    Stage 1 ranks every cocktail against *alcohol_query*; stage 2 re-ranks
    the top 300 against *fruit_query*. The best three are rendered as
    markdown, followed by one locally generated AI cocktail. When
    *generate_image* is true, a Stable Diffusion image of the top match is
    produced as well.

    Returns ``(markdown_text, image_or_None)`` for the Gradio outputs.
    """
    global processed_data, raw_dataset, embedding_model

    # Encode every cocktail's ingredient text ONCE and reuse the vectors in
    # both stages. The previous version called encode() per item inside two
    # loops, re-embedding the same 300 texts a second time in stage 2.
    texts = [" ".join(item['tokens']) for item in processed_data]
    text_vectors = embedding_model.encode(texts, convert_to_tensor=True)

    def rank(query, candidates):
        # candidates: iterable of (data_item, text, vector). Returns dicts
        # sorted by cosine similarity to the query; empty texts score 0.0,
        # matching the original behavior.
        query_vec = embedding_model.encode(query, convert_to_tensor=True)
        scored = []
        for data_item, text, vec in candidates:
            score = util.cos_sim(query_vec, vec).item() if text else 0.0
            scored.append({'score': score, 'data': data_item, 'text': text, 'vec': vec})
        scored.sort(key=lambda x: x['score'], reverse=True)
        return scored

    stage1_results = rank(alcohol_query, zip(processed_data, texts, text_vectors))
    top_300 = [(r['data'], r['text'], r['vec']) for r in stage1_results[:300]]
    stage2_results = rank(fruit_query, top_300)

    recommendations_text, recommended_cocktails = _format_recommendations(stage2_results[:3])
    recommendations_text += _format_ai_section(alcohol_query, fruit_query)

    image = None
    if recommended_cocktails and generate_image:
        image = _generate_cocktail_image(recommended_cocktails[0])
    return recommendations_text, image
if __name__ == "__main__":
    # ---- One-time startup: data preparation and model loading ----
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        processed_data = process_all()
        raw_dataset = load_dataset("erwanlc/cocktails_recipe")["train"]
        embedding_model_name = 'paraphrase-MiniLM-L6-v2'
        print(f"\nืืืขื ืืช ืืืื ืืืืงืืืจืื: {embedding_model_name}...")
        embedding_model = SentenceTransformer(embedding_model_name)
        print("ืืืื ืืืืงืืืจืื ืืืื.")
        llm_model_name = "google/gemma-2b-it"
        print(f"\nืืืขื ืืช ืืืื ืืฉืคื: {llm_model_name}...")
        try:
            # Load the LLM 4-bit quantised so it needs less memory.
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16
            )
            llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
            llm_model = AutoModelForCausalLM.from_pretrained(
                llm_model_name,
                quantization_config=quantization_config,
                device_map="auto"
            )
            print("ืืืื ืืฉืคื ืืืื.")
        except Exception as e:
            # Degrade gracefully: the app still serves dataset-only results
            # (generate_and_select_ai_cocktail returns None when these are None).
            print(f"Error loading LLM: {e}")
            llm_model = None
            llm_tokenizer = None
            print("ืืืคืืืงืฆืื ืชืืฉืื ืืจืืฅ ืืื ืงืืงืืืืืื ืืืฉืื ืืืืืื ืืืงืืื.")

    # ---- Gradio UI ----
    alcohol_options = ["Vodka", "Whiskey", "Gin", "Tequila", "Wine"]
    ingredient_options = ["Orange Juice", "Apple", "Cranberry", "Lemon", "Thyme"]
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # ๐ธ Cocktail Creation Assistant ๐ธ
            Enter a base alcohol and a main ingredient to get three cocktail recipes from the dataset,
            plus one unique recipe created by the local AI model.
            Use the buttons for quick suggestions!
            """
        )
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 1. Choose Your Ingredients")
                alcohol_input = gr.Textbox(label="Base Alcohol", placeholder="e.g., Gin, Rum, Vodka...")
                with gr.Row():
                    for alcohol in alcohol_options:
                        btn = gr.Button(alcohol)
                        # val=alcohol binds the current loop value as a default,
                        # avoiding the late-binding closure pitfall.
                        btn.click(lambda val=alcohol: val, outputs=alcohol_input)
                ingredient_input = gr.Textbox(label="Main Ingredient", placeholder="e.g., Orange, Lime, Mint...")
                with gr.Row():
                    for ingredient in ingredient_options:
                        btn = gr.Button(ingredient)
                        btn.click(lambda val=ingredient: val, outputs=ingredient_input)
                generate_image_checkbox = gr.Checkbox(label="Generate image for top result?", value=False)
                submit_btn = gr.Button("๐น Get My Cocktails!", variant="primary")
            with gr.Column(scale=2):
                gr.Markdown("### 2. Your Recommendations")
                output_markdown = gr.Markdown(label="Cocktail Recommendations")
                output_image = gr.Image(label="Generated Cocktail Photo")
        # Wire the submit button to the two-stage search.
        submit_btn.click(
            fn=predict_cocktails,
            inputs=[alcohol_input, ingredient_input, generate_image_checkbox],
            outputs=[output_markdown, output_image]
        )
    demo.launch(debug=True)