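"""Gradio demo for somosnlp-hackathon-2025/Llama-2-7b-hf-lora-refranes.

Explains the meaning of Spanish proverbs (refranes) using Llama-2-7b with a
LoRA adapter. The model is loaded lazily on the first request so the Space
starts up quickly.
"""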
import os
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# Default generation hyperparameters (exposed as sliders in the UI below).
MAX_NEW_TOKENS = 100
TEMPERATURE = 0.5
TOP_P = 0.95
TOP_K = 50
REPETITION_PENALTY = 1.05

# Separator appended to the prompt; the explanation is read from the text
# that follows it in the model output.
SPECIAL_TOKEN = "->:"

# Access token for the gated meta-llama/Llama-2-7b-hf checkpoint.
HF_TOKEN = os.getenv("HF_TOKEN")
def load_model():
    """Load the Llama-2 base model and attach the LoRA adapter."""
    base_model_id = "meta-llama/Llama-2-7b-hf"
    peft_model_id = "somosnlp-hackathon-2025/Llama-2-7b-hf-lora-refranes"

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype="auto",
        device_map="auto",
        token=HF_TOKEN,
    )
    model = PeftModel.from_pretrained(base_model, peft_model_id)

    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    # Llama-2 ships without a pad token; reuse EOS so generation does not warn.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer
# Populated lazily inside generate_response so the app starts without waiting
# for the 7B weights to download.
model = None
tokenizer = None
def generate_response(input_text, max_tokens, temperature, top_p, repetition_penalty):
    global model, tokenizer
    # Lazy initialisation: the first request pays the model-loading cost.
    if model is None or tokenizer is None:
        model, tokenizer = load_model()

    # The prompt format is "<refrán>->:"; the model completes with the explanation.
    inputs = tokenizer(input_text + SPECIAL_TOKEN, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=TOP_K,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.pad_token_id,  # avoids the missing-pad-token warning
        )

    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the text after the separator (the explanation itself).
    if SPECIAL_TOKEN in full_response:
        return full_response.split(SPECIAL_TOKEN, 1)[1].strip()
    return full_response.strip()
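
# Quick smoke test outside the Gradio UI (a sketch: assumes a GPU and a valid
# HF_TOKEN; the proverb is taken from the examples list below):
#   print(generate_response(
#       "No por mucho madrugar amanece más temprano",
#       MAX_NEW_TOKENS, TEMPERATURE, TOP_P, REPETITION_PENALTY,
#   ))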
def chat_interface(message, history, system_message, max_tokens, temperature, top_p, repetition_penalty):
    # `history` is ignored: each proverb is explained independently.
    prompt = f"{system_message}\n{message}" if system_message else message
    return generate_response(
        prompt,
        max_tokens,
        temperature,
        top_p,
        repetition_penalty,
    )
demo = gr.ChatInterface(
    chat_interface,
    title="Sabiduría Popular - Refranes",
    description=(
        "Esta aplicación explica el significado de refranes en español "
        "utilizando un modelo de lenguaje. Escribe un refrán y el modelo "
        "te explicará su significado."
    ),
    examples=[
        ["A caballo regalado no le mires el diente"],
        ["Más vale pájaro en mano que ciento volando"],
        ["Quien a buen árbol se arrima, buena sombra le cobija"],
        ["No por mucho madrugar amanece más temprano"],
    ],
    additional_inputs=[
        gr.Textbox(
            value=(
                "Eres un experto en sabiduría popular española. Tu tarea es "
                "explicar el significado de refranes en español de manera "
                "clara y concisa."
            ),
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=500, value=MAX_NEW_TOKENS, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=TEMPERATURE, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=TOP_P, step=0.05, label="Top-p (nucleus sampling)"),
        gr.Slider(minimum=1.0, maximum=2.0, value=REPETITION_PENALTY, step=0.05, label="Repetition penalty"),
    ],
    theme="soft",
)
if __name__ == "__main__":
    print("Iniciando la aplicación. El modelo se cargará con la primera consulta.")
    demo.launch()
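
# Running locally (a sketch; assumes access to the gated meta-llama weights):
#   export HF_TOKEN=hf_...   # placeholder token with Llama-2 access
#   pip install gradio torch transformers peft accelerate
#   python app.py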