paulaschez's picture
fix
6f07175
import gradio as gr
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TextIteratorStreamer,
StoppingCriteria,
StoppingCriteriaList,
)
from peft import PeftModel
from threading import Thread
# ── 1. Configuration ─────────────────────────────────────────────────────────
# Hugging Face Hub repositories: the base checkpoint, and the LoRA adapter that
# is loaded on top of it.
BASE_MODEL_ID: str = "mistralai/Mistral-7B-v0.1"
ADAPTER_MODEL_ID: str = "paulaschez/Mistral-7B-AI-Chef"

# Instruction text embedded in the [INST] section of every prompt.
SYSTEM_PROMPT: str = (
    "You are AI Chef, an advanced culinary and nutrition assistant for SmartKitchen Solutions. "
    "Given a list of available ingredients and dietary restrictions, "
    "generate a complete, structured recipe with nutritional information."
)

# Upper bound on newly generated tokens per response (kept small for CPU use).
MAX_NEW_TOKENS: int = 100
# ── 2. Load the model in 4-bit ───────────────────────────────────────────────
print("Loading model...")

# 4-bit NF4 weights with nested ("double") quantization, computing in float16.
# NOTE(review): 4-bit bitsandbytes loading with device_map="cpu" needs a
# bitsandbytes build with CPU support, and float16 compute on CPU may be slow —
# confirm against the deployed library versions.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
# Padding reuses the EOS token.
tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    device_map="cpu",
    low_cpu_mem_usage=True,
    quantization_config=bnb_config,
)

# Wrap the base model with the LoRA adapter and switch to inference mode.
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
model.eval()
print("Model loaded ✅")
class StopOnSignal(StoppingCriteria):
    """Stopping criterion that ends generation once a shared flag is raised.

    The UI's stop handler sets ``state_dict["stop"] = True``; ``model.generate``
    consults this criterion during decoding and stops when it reports True.
    """

    def __init__(self, state_dict):
        # Shared, mutable session dict; only its "stop" key is read.
        self.state_dict = state_dict

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> torch.BoolTensor:
        """Return one stop flag per sequence in the batch.

        Args:
            input_ids: Token ids generated so far; only the batch size is used.
            scores: Next-token scores (unused).

        Returns:
            A bool tensor of shape ``(batch_size,)``, all True once the flag is set.
        """
        should_stop = bool(self.state_dict.get("stop", False))
        # transformers combines criteria per batch row, like its built-in
        # MaxLengthCriteria, so return a per-row bool tensor, not a bare bool.
        return torch.full(
            (input_ids.shape[0],),
            should_stop,
            dtype=torch.bool,
            device=input_ids.device,
        )
# ── 3. Comparative generation (base model vs. fine-tuned) ────────────────────
def _generate_worker(gen_kwargs, streamer, run):
    """Run ``model.generate`` (in a worker thread), storing its output or error in ``run``."""
    try:
        run["output"] = model.generate(**gen_kwargs, streamer=streamer)
    except Exception as err:  # thread boundary: re-raised by the consumer
        run["error"] = err
        # Without an end-of-stream signal the consumer would block forever
        # waiting on the streamer.
        streamer.end()


def _stream_generation(gen_kwargs, user_state, run):
    """Stream one generation, yielding the text accumulated so far after each chunk.

    Stops consuming as soon as ``user_state["stop"]`` is set. The worker thread
    is always joined before this generator finishes (also when it is closed
    early), and an exception raised by the worker is re-raised here.
    """
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    worker = Thread(target=_generate_worker, args=(gen_kwargs, streamer, run))
    worker.start()
    text = ""
    exited_normally = False
    try:
        for chunk in streamer:
            if user_state["stop"]:
                break
            text += chunk
            yield text
        exited_normally = True
    finally:
        if not exited_normally:
            # Closed mid-stream (e.g. the UI event was cancelled): raise the flag
            # so StopOnSignal ends the worker instead of letting it run on.
            user_state["stop"] = True
        worker.join()
    if "error" in run:
        raise run["error"]


def _hit_token_limit(run, prompt_len, max_new_tokens):
    """Return True when a finished generation used its whole budget without emitting EOS."""
    output = run.get("output")
    if output is None:
        return False
    sequences = getattr(output, "sequences", output)
    new_ids = sequences[0, prompt_len:]
    if new_ids.numel() < max_new_tokens:
        return False
    return new_ids[-1].item() != tokenizer.eos_token_id


def generate_comparative(ingredients, restrictions, user_state):
    """Stream the base model's answer, then the fine-tuned model's, for one request.

    Generator used as a Gradio event handler: every yield is a
    ``(base_text, fine_tuned_text)`` pair for the two output boxes.

    Args:
        ingredients: Free-text ingredient list; blank input yields a warning only.
        restrictions: Free-text dietary restrictions; blank falls back to
            "no specific restrictions".
        user_state: Session dict whose "stop" key is raised by the Stop button.
            It is reset to False at the start of every call.

    Raises:
        Any exception raised by ``model.generate`` in the worker thread.
    """
    user_state["stop"] = False
    if not (ingredients or "").strip():
        yield "⚠️ Please enter some ingredients.", ""
        return
    restrictions = (restrictions or "").strip() or "no specific restrictions"
    prompt = (
        f"<s>[INST] {SYSTEM_PROMPT}\n\n"
        f"Ingredients: {ingredients}\n"
        f"Dietary Restrictions: {restrictions} [/INST]"
    )
    # NOTE(review): the tokenizer also prepends BOS by default, so the literal
    # "<s>" likely produces a double BOS. Keep it only if the adapter was trained
    # with the same prompt encoding; otherwise pass add_special_tokens=False.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        temperature=0.3,
        do_sample=True,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
        stopping_criteria=StoppingCriteriaList([StopOnSignal(user_state)]),
    )
    truncation_note = (
        f"\n\n---\n⚠️ *Response truncated at {MAX_NEW_TOKENS} tokens due to "
        f"CPU limitations. The full recipe would continue beyond this point.*"
    )

    # ── Base model ──
    # The adapter stays disabled until the worker thread has been joined, so
    # every base token really comes from the base weights.
    base_text, base_run = "", {}
    with model.disable_adapter():
        base_stream = _stream_generation(gen_kwargs, user_state, base_run)
        try:
            for base_text in base_stream:
                yield base_text, "⏳ Waiting for base model to finish..."
        finally:
            base_stream.close()
    if user_state["stop"]:
        yield base_text + "\n\n[🛑 Stopped by user]", "🛑 Canceled."
        return
    if _hit_token_limit(base_run, prompt_len, MAX_NEW_TOKENS):
        base_text += truncation_note
    yield base_text, "⏳ Generating AI Chef response..."

    # ── Fine-tuned model (adapter enabled again) ──
    ft_text, ft_run = "", {}
    ft_stream = _stream_generation(gen_kwargs, user_state, ft_run)
    try:
        for ft_text in ft_stream:
            yield base_text, ft_text
    finally:
        ft_stream.close()
    if user_state["stop"]:
        yield base_text, ft_text + "\n\n[🛑 Stopped by user]"
        return
    if _hit_token_limit(ft_run, prompt_len, MAX_NEW_TOKENS):
        ft_text += truncation_note
    yield base_text, ft_text
# ── 4. Gradio interface ──────────────────────────────────────────────────────
# Single source for the idle label of the generate button, so the label restored
# by post_generate cannot drift from the initial one (it previously read
# "comparare").
_GENERATE_LABEL = "🍳 Generate and compare"

with gr.Blocks() as demo:
    gr.Markdown("# 🧑‍🍳 AI Chef — Model Comparison")
    gr.Markdown(
        f"> ⚙️ **Hardware**: CPU basic (free tier) — "
        f"Generation limit: **{MAX_NEW_TOKENS} tokens** per response. "
        f"Responses may be truncated. A GPU environment would allow full generation."
    )
    gr.Markdown(
        "Compare the performance of **Mistral-7B** without adjustment (Baseline) "
        "vs. the model after applying specialized **QLoRA fine-tuning** "
        "in recipes and nutrition."
    )

    with gr.Row():
        ing_input = gr.Textbox(
            label="🥕 Available ingredients",
            placeholder="E.g.: eggs, tomato, onion, olive oil",
        )
        res_input = gr.Textbox(
            label="🥗 Dietary restrictions",
            placeholder="E.g.: vegan, gluten-free, high protein...",
            value="no specific restrictions",
        )

    # Session state holding the "stop" flag shared by the generate and stop
    # handlers.
    session_state = gr.State({"stop": False})

    with gr.Row():
        btn = gr.Button(_GENERATE_LABEL, variant="primary")
        stop_btn = gr.Button("⏹️ Stop process", variant="stop", interactive=False)

    with gr.Row():
        base_output = gr.Textbox(
            label="❌ Base Model — Mistral-7B-v0.1",
            lines=18,
        )
        ft_output = gr.Textbox(
            label="✅ AI Chef — Fine-Tuned Model",
            lines=18,
        )

    gr.Examples(
        examples=[
            ["eggs, tomato, onion, olive oil", "no specific restrictions"],
            ["lentils, spinach, garlic, cumin, tomato", "vegan"],
            ["oats, banana, almond milk, honey", "gluten-free"],
            ["chicken breast, broccoli, brown rice", "high protein, low fat"],
            ["pasta, heavy cream, parmesan cheese", "vegetarian"],
        ],
        inputs=[ing_input, res_input],
    )

    # ── Button and state logic ──
    # 1. Helpers that switch the buttons between their busy and idle looks.
    def pre_generate():
        """Show the busy label, disable Generate and enable Stop."""
        return (
            gr.update(value="⏳ Cooking answer...", interactive=False),
            gr.update(interactive=True),
        )

    def post_generate():
        """Restore the idle label, re-enable Generate and disable Stop."""
        return (
            gr.update(value=_GENERATE_LABEL, interactive=True),
            gr.update(interactive=False),
        )

    def trigger_stop(user_state):
        """Raise the session's stop flag and reset both buttons."""
        user_state["stop"] = True
        return post_generate()

    # 2. Chained events (.then): update buttons -> generate -> restore buttons.
    gen_event = (
        btn.click(fn=pre_generate, outputs=[btn, stop_btn])
        .then(
            fn=generate_comparative,
            inputs=[ing_input, res_input, session_state],
            outputs=[base_output, ft_output],
        )
        .then(fn=post_generate, outputs=[btn, stop_btn])
    )

    # 3. Stop event: raise the flag, restore the buttons and cancel gen_event.
    # NOTE(review): gen_event is the value returned by the last .then()
    # (post_generate). The running generation itself is ended by the "stop"
    # flag, which generate_comparative and StopOnSignal read.
    stop_btn.click(
        fn=trigger_stop,
        inputs=[session_state],
        outputs=[btn, stop_btn],
        cancels=[gen_event],
    )

# NOTE(review): passing theme to launch() matches Gradio 6; older Gradio
# versions take it as gr.Blocks(theme=...) — confirm the pinned version.
demo.launch(theme=gr.themes.Soft())