Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| BitsAndBytesConfig, | |
| TextIteratorStreamer, | |
| StoppingCriteria, | |
| StoppingCriteriaList, | |
| ) | |
| from peft import PeftModel | |
| from threading import Thread | |
# ── 1. Configuration ──────────────────────────────────────────────────────────
# Model identifiers passed to `from_pretrained`: the base checkpoint and the
# LoRA adapter that is loaded on top of it.
BASE_MODEL_ID = "mistralai/Mistral-7B-v0.1"
ADAPTER_MODEL_ID = "paulaschez/Mistral-7B-AI-Chef"

# Instruction text placed right after `[INST]` in every prompt.
SYSTEM_PROMPT = " ".join(
    (
        "You are AI Chef, an advanced culinary and nutrition assistant",
        "for SmartKitchen Solutions. Given a list of available ingredients",
        "and dietary restrictions, generate a complete, structured recipe",
        "with nutritional information.",
    )
)

# Upper bound on newly generated tokens per answer; kept small because the
# app advertises running on CPU-only hardware.
MAX_NEW_TOKENS = 100
# ── 2. Load the model in 4-bit ───────────────────
print("Loading model...")

# NF4 4-bit weights with nested ("double") quantization; compute in float16.
# NOTE(review): 4-bit bitsandbytes together with device_map="cpu" depends on
# the installed bitsandbytes/transformers versions supporting CPU — confirm.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
# Reuse EOS as the padding token (presumably the base tokenizer defines none).
tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    low_cpu_mem_usage=True,
    device_map="cpu",
    quantization_config=bnb_config,
)

# Attach the LoRA adapter on top of the quantized base model. The resulting
# PeftModel can still run the plain base via `model.disable_adapter()`.
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
model.eval()  # inference only: disables dropout
print("Model loaded ✅")
class StopOnSignal(StoppingCriteria):
    """Stopping criterion that halts `generate` once a shared flag is raised.

    The flag is the ``"stop"`` key of a mutable mapping (in this app, the
    per-session Gradio state). Another handler (``trigger_stop``) sets it to
    True while generation runs in a background thread.
    """

    def __init__(self, state_dict):
        # Keep a reference, not a copy, so later writes to the mapping are seen.
        self.state_dict = state_dict

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> torch.BoolTensor:
        """Return a per-sequence BoolTensor that is all True once stop is requested.

        `StoppingCriteria.__call__` is documented to return a BoolTensor of
        shape ``(batch_size,)``. Returning a plain Python bool only worked
        through implicit broadcasting in `StoppingCriteriaList`.
        """
        should_stop = bool(self.state_dict.get("stop", False))
        return torch.full(
            (input_ids.shape[0],),
            should_stop,
            dtype=torch.bool,
            device=input_ids.device,
        )
# ── 3. Comparative generation function ───────────────────────────────────────
def _start_generation(gen_kwargs):
    """Run `model.generate` in a background thread, streaming decoded text.

    Returns ``(streamer, thread, result)``. The worker fills ``result`` with
    ``"output_ids"`` (the tensor returned by `generate`) on success, or with
    ``"error"`` (the raised exception) on failure.
    """
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    result = {}

    def _worker():
        try:
            result["output_ids"] = model.generate(**gen_kwargs, streamer=streamer)
        except Exception as exc:  # re-raised by the caller after join()
            result["error"] = exc
            # Send end-of-stream so the consumer's `for chunk in streamer`
            # loop terminates instead of blocking forever.
            streamer.end()

    thread = Thread(target=_worker)
    thread.start()
    return streamer, thread, result


def _was_truncated(result, prompt_len):
    """Return True if generation used all MAX_NEW_TOKENS without ending on EOS."""
    output_ids = result.get("output_ids")
    if output_ids is None:
        return False
    # Count real generated tokens. Streamer chunks are buffered to word
    # boundaries, so counting chunks would undercount.
    generated = output_ids.shape[-1] - prompt_len
    return (
        generated >= MAX_NEW_TOKENS
        and output_ids[0, -1].item() != tokenizer.eos_token_id
    )


def _truncation_note():
    """Footer appended to an answer that hit the generation limit."""
    return (
        f"\n\n---\n⚠️ *Response truncated at {MAX_NEW_TOKENS} tokens due to "
        f"CPU limitations. The full recipe would continue beyond this point.*"
    )


def generate_comparative(ingredients, restrictions, user_state):
    """Stream the base model's answer, then the fine-tuned model's answer.

    Args:
        ingredients: Free-text list of available ingredients.
        restrictions: Free-text dietary restrictions; blank means none.
        user_state: Per-session mapping. Its ``"stop"`` key is polled here
            (and by `StopOnSignal`) to abort generation cooperatively.

    Yields:
        ``(base_text, ft_text)`` tuples for the two output textboxes.

    Raises:
        Any exception raised by `model.generate` in the worker thread.
    """
    # Clear any stop request left over from a previous run.
    user_state["stop"] = False

    if not (ingredients or "").strip():
        yield "⚠️ Please enter some ingredients.", ""
        return

    restrictions = (restrictions or "").strip() or "no specific restrictions"
    prompt = (
        f"<s>[INST] {SYSTEM_PROMPT}\n\n"
        f"Ingredients: {ingredients}\n"
        f"Dietary Restrictions: {restrictions} [/INST]"
    )
    # NOTE(review): the prompt starts with a literal "<s>". If the tokenizer
    # also prepends BOS, the input begins with two BOS tokens. Left unchanged
    # to match the format the adapter was presumably trained with.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        temperature=0.3,
        do_sample=True,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
        stopping_criteria=StoppingCriteriaList([StopOnSignal(user_state)]),
    )

    # ── Base model (LoRA adapter disabled) ───────────────────────────────────
    # The full generation, including join(), stays inside the context manager
    # so the adapter is not re-enabled while the base model is still decoding.
    base_text = ""
    with model.disable_adapter():
        streamer, thread, base_result = _start_generation(gen_kwargs)
        try:
            for chunk in streamer:
                if user_state["stop"]:
                    break
                base_text += chunk
                yield base_text, "⏳ Waiting for base model to finish..."
        finally:
            thread.join()
    if "error" in base_result:
        raise base_result["error"]

    if user_state["stop"]:
        yield base_text + "\n\n[🛑 Stopped by user]", "🛑 Canceled."
        return
    if _was_truncated(base_result, prompt_len):
        base_text += _truncation_note()
    yield base_text, "⏳ Generating AI Chef response..."

    # ── Fine-tuned model (LoRA adapter active) ───────────────────────────────
    ft_text = ""
    streamer, thread, ft_result = _start_generation(gen_kwargs)
    try:
        for chunk in streamer:
            if user_state["stop"]:
                break
            ft_text += chunk
            yield base_text, ft_text
    finally:
        thread.join()
    if "error" in ft_result:
        raise ft_result["error"]

    if user_state["stop"]:
        yield base_text, ft_text + "\n\n[🛑 Stopped by user]"
        return
    if _was_truncated(ft_result, prompt_len):
        ft_text += _truncation_note()
    yield base_text, ft_text
# ── 4. Gradio interface ───────────────────────────────────────────────────────
with gr.Blocks() as demo:
    # "\u200d" (zero-width joiner) fuses 🧑 and 🍳 into the single chef emoji.
    gr.Markdown("# 🧑\u200d🍳 AI Chef — Model Comparison")
    gr.Markdown(
        f"> ⚙️ **Hardware**: CPU basic (free tier) — "
        f"Generation limit: **{MAX_NEW_TOKENS} tokens** per response. "
        f"Responses may be truncated. A GPU environment would allow full generation."
    )
    gr.Markdown(
        "Compare the performance of **Mistral-7B** without adjustment (Baseline) "
        "vs. the model after applying specialized **QLoRA fine-tuning** "
        "in recipes and nutrition."
    )
    with gr.Row():
        ing_input = gr.Textbox(
            label="🥕 Available ingredients",
            placeholder="E.g.: eggs, tomato, onion, olive oil",
        )
        res_input = gr.Textbox(
            label="🥗 Dietary restrictions",
            placeholder="E.g.: vegan, gluten-free, high protein...",
            value="no specific restrictions",
        )
    # Per-session mutable flag shared by the generate and stop handlers.
    session_state = gr.State({"stop": False})
    with gr.Row():
        btn = gr.Button("🍳 Generate and compare", variant="primary")
        stop_btn = gr.Button("⏹️ Stop process", variant="stop", interactive=False)
    with gr.Row():
        base_output = gr.Textbox(
            label="❌ Base Model — Mistral-7B-v0.1",
            lines=18,
        )
        ft_output = gr.Textbox(
            label="✅ AI Chef — Fine-Tuned Model",
            lines=18,
        )
    gr.Examples(
        examples=[
            ["eggs, tomato, onion, olive oil", "no specific restrictions"],
            ["lentils, spinach, garlic, cumin, tomato", "vegan"],
            ["oats, banana, almond milk, honey", "gluten-free"],
            ["chicken breast, broccoli, brown rice", "high protein, low fat"],
            ["pasta, heavy cream, parmesan cheese", "vegetarian"],
        ],
        inputs=[ing_input, res_input],
    )

    # ── Button logic and state ──
    # 1. Helpers that switch the buttons' appearance.
    def pre_generate():
        """Show the busy label on the generate button and enable Stop."""
        return (
            gr.update(value="⏳ Cooking answer...", interactive=False),
            gr.update(interactive=True),
        )

    def post_generate():
        """Restore the generate button's original label and disable Stop."""
        return (
            gr.update(value="🍳 Generate and compare", interactive=True),
            gr.update(interactive=False),
        )

    def trigger_stop(user_state):
        """Raise the per-session stop flag and restore the buttons."""
        user_state["stop"] = True
        return post_generate()

    # 2. Chained events (.then): update buttons -> generate -> restore buttons.
    gen_event = (
        btn.click(fn=pre_generate, outputs=[btn, stop_btn])
        .then(
            fn=generate_comparative,
            inputs=[ing_input, res_input, session_state],
            outputs=[base_output, ft_output],
        )
        .then(fn=post_generate, outputs=[btn, stop_btn])
    )
    # 3. Stop event: restore the buttons and request cancellation.
    # NOTE: gen_event is the final .then() step (post_generate), so `cancels`
    # targets that step. generate_comparative itself stops cooperatively:
    # trigger_stop sets user_state["stop"], which it polls and then reports
    # "[🛑 Stopped by user]".
    stop_btn.click(
        fn=trigger_stop,
        inputs=[session_state],
        outputs=[btn, stop_btn],
        cancels=[gen_event],
    )

# NOTE(review): `theme` is a gr.Blocks() argument in Gradio 4/5; newer Gradio
# releases accept it in launch() — confirm against the installed version.
demo.launch(theme=gr.themes.Soft())