import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList,
)
from peft import PeftModel
from threading import Thread

# ── 1. Configuration ──────────────────────────────────────────────────────────
BASE_MODEL_ID = "mistralai/Mistral-7B-v0.1"
ADAPTER_MODEL_ID = "paulaschez/Mistral-7B-AI-Chef"

# Instruction text placed at the start of every prompt.
SYSTEM_PROMPT = (
    "You are AI Chef, an advanced culinary and nutrition assistant "
    "for SmartKitchen Solutions. Given a list of available ingredients "
    "and dietary restrictions, generate a complete, structured recipe "
    "with nutritional information."
)

# Upper bound on newly generated tokens per response.
MAX_NEW_TOKENS = 100

# ── 2. Load the model in 4 bits ───────────────────────────────────────────────
print("Loading model...")

# NF4 quantization with double quantization and float16 compute.
quantization_options = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": torch.float16,
    "bnb_4bit_use_double_quant": True,
}
bnb_config = BitsAndBytesConfig(**quantization_options)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    quantization_config=bnb_config,
    device_map="cpu",
    low_cpu_mem_usage=True,
)

# Load the LoRA adapter on top of the base model.
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
model.eval()

print("Model loaded ✅")


class StopOnSignal(StoppingCriteria):
    """Stopping criterion controlled by a ``"stop"`` flag in a dictionary.

    The dictionary is stored by reference, so a change made to its ``"stop"``
    entry after construction is seen the next time the criterion is called.
    """

    def __init__(self, state_dict):
        """Remember the dictionary whose ``"stop"`` entry is consulted."""
        self.state_dict = state_dict

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        """Return the current ``"stop"`` value (``False`` when the key is absent).

        ``input_ids`` and ``scores`` are accepted but not inspected.
        """
        stop_requested = self.state_dict.get("stop", False)
        return stop_requested
# ── 3. Comparative generation function ────────────────────────────────────────
def generate_comparative(ingredients, restrictions, user_state):
    """Stream a recipe from the base model, then from the fine-tuned model.

    This is a generator: every yielded value is a ``(base_text, ft_text)``
    pair holding the text to show for the base model and for the fine-tuned
    model respectively.

    Args:
        ingredients: Free-text list of available ingredients. When it is
            blank, a single warning pair is yielded and nothing is generated.
        restrictions: Free-text dietary restrictions; a blank value is
            replaced by "no specific restrictions".
        user_state: Mutable dictionary whose ``"stop"`` entry is reset to
            ``False`` here and polled during generation. Setting it to
            ``True`` elsewhere ends the run early.
    """
    user_state["stop"] = False
    stop_tracker = StopOnSignal(user_state)

    if not ingredients.strip():
        yield "⚠️ Please enter some ingredients.", ""
        return

    restrictions = restrictions.strip() or "no specific restrictions"
    prompt = (
        f"[INST] {SYSTEM_PROMPT}\n\n"
        f"Ingredients: {ingredients}\n"
        f"Dietary Restrictions: {restrictions} [/INST]"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Shared by both generation passes; only the streamer differs.
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        temperature=0.3,
        do_sample=True,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
        stopping_criteria=StoppingCriteriaList([stop_tracker]),
    )

    # Appended to a response that appears to have hit the generation limit.
    truncation_notice = (
        f"\n\n---\n⚠️ *Response truncated at {MAX_NEW_TOKENS} tokens due to "
        f"CPU limitations. The full recipe would continue beyond this point.*"
    )

    # ── BASE model ────────────────────────────────────────────────────────────
    streamer_base = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    # The adapter stays disabled until the worker thread has been joined, so
    # the whole base-model pass runs without the LoRA weights.
    with model.disable_adapter():
        thread = Thread(
            target=model.generate,
            kwargs={**gen_kwargs, "streamer": streamer_base},
        )
        thread.start()

        base_text = ""
        chunk_count = 0
        for chunk in streamer_base:
            if user_state["stop"]:
                break
            base_text += chunk
            chunk_count += 1
            yield base_text, "⏳ Waiting for base model to finish..."
        thread.join()

    if user_state["stop"]:
        yield base_text + "\n\n[🛑 Stopped by user]", "🛑 Canceled."
        return

    # NOTE(review): this counts items emitted by the streamer and compares
    # them with the token limit; the two may not match exactly — confirm
    # against the streamer's behavior.
    if chunk_count >= MAX_NEW_TOKENS:
        base_text += truncation_notice

    yield base_text, "⏳ Generating AI Chef response..."

    # ── FINE-TUNED model ──────────────────────────────────────────────────────
    streamer_ft = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    thread = Thread(
        target=model.generate,
        kwargs={**gen_kwargs, "streamer": streamer_ft},
    )
    thread.start()

    ft_text = ""
    chunk_count = 0
    for chunk in streamer_ft:
        if user_state["stop"]:
            break
        ft_text += chunk
        chunk_count += 1
        yield base_text, ft_text
    thread.join()

    if user_state["stop"]:
        yield base_text, ft_text + "\n\n[🛑 Stopped by user]"
        return

    # Same approximate limit check as for the base model above.
    if chunk_count >= MAX_NEW_TOKENS:
        ft_text += truncation_notice

    yield base_text, ft_text


# ── 4. Gradio interface ───────────────────────────────────────────────────────
with gr.Blocks() as demo:
    gr.Markdown("# 🧑‍🍳 AI Chef — Model Comparison")
    gr.Markdown(
        f"> ⚙️ **Hardware**: CPU basic (free tier) — "
        f"Generation limit: **{MAX_NEW_TOKENS} tokens** per response. "
        f"Responses may be truncated. A GPU environment would allow full generation."
    )
    gr.Markdown(
        "Compare the performance of **Mistral-7B** without adjustment (Baseline) "
        "vs. the model after applying specialized **QLoRA fine-tuning** "
        "in recipes and nutrition."
    )

    with gr.Row():
        ing_input = gr.Textbox(
            label="🥕 Available ingredients",
            placeholder="E.g.: eggs, tomato, onion, olive oil",
        )
        res_input = gr.Textbox(
            label="🥗 Dietary restrictions",
            placeholder="E.g.: vegan, gluten-free, high protein...",
            value="no specific restrictions",
        )

    # Dictionary carrying the "stop" flag shared by the generate and stop handlers.
    session_state = gr.State({"stop": False})

    # Single source for the idle label of the generate button, so the label
    # restored after a run always matches the initial one.
    generate_label = "🍳 Generate and compare"

    with gr.Row():
        btn = gr.Button(generate_label, variant="primary")
        stop_btn = gr.Button("⏹️ Stop process", variant="stop", interactive=False)

    with gr.Row():
        base_output = gr.Textbox(
            label="❌ Base Model — Mistral-7B-v0.1",
            lines=18,
        )
        ft_output = gr.Textbox(
            label="✅ AI Chef — Fine-Tuned Model",
            lines=18,
        )

    gr.Examples(
        examples=[
            ["eggs, tomato, onion, olive oil", "no specific restrictions"],
            ["lentils, spinach, garlic, cumin, tomato", "vegan"],
            ["oats, banana, almond milk, honey", "gluten-free"],
            ["chicken breast, broccoli, brown rice", "high protein, low fat"],
            ["pasta, heavy cream, parmesan cheese", "vegetarian"],
        ],
        inputs=[ing_input, res_input],
    )

    # ── BUTTON AND STATE LOGIC ──

    # 1. Functions that change the visual state of the buttons
    def pre_generate():
        """Return updates that lock the generate button and enable the stop button."""
        return gr.update(value="⏳ Cooking answer...", interactive=False), gr.update(
            interactive=True
        )

    def post_generate():
        """Return updates that restore the generate button and disable the stop button."""
        return gr.update(value=generate_label, interactive=True), gr.update(
            interactive=False
        )

    def trigger_stop(user_state):
        """Raise the stop flag in ``user_state`` and return the idle button updates."""
        user_state["stop"] = True
        return post_generate()

    # 2. Chained event (.then): update buttons -> generate -> restore buttons
    gen_event = (
        btn.click(fn=pre_generate, outputs=[btn, stop_btn])
        .then(
            fn=generate_comparative,
            inputs=[ing_input, res_input, session_state],
            outputs=[base_output, ft_output],
        )
        .then(fn=post_generate, outputs=[btn, stop_btn])
    )

    # 3. Stop event: cancels the running process and restores the buttons
    stop_btn.click(
        fn=trigger_stop,
        inputs=[session_state],
        outputs=[btn, stop_btn],
        cancels=[gen_event],
    )

demo.launch(theme=gr.themes.Soft())