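"""Gradio Space that explains Spanish proverbs (refranes).

A LoRA adapter fine-tuned on top of meta-llama/Llama-2-7b-hf generates a
short explanation for the proverb the user types in. The base model is
gated on the Hugging Face Hub, so the HF_TOKEN secret must be set.
"""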
import os
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Default sampling parameters; most are also exposed as sliders in the UI.
MAX_NEW_TOKENS = 100
TEMPERATURE = 0.5
TOP_P = 0.95
TOP_K = 50
REPETITION_PENALTY = 1.05

# Separator between the proverb and its explanation in the fine-tuning
# format; everything generated after it is treated as the answer.
SPECIAL_TOKEN = "->:"

# Access token for downloading the gated Llama-2 base model.
HF_TOKEN = os.getenv("HF_TOKEN")

def load_model():
    base_model_id = "meta-llama/Llama-2-7b-hf"
    peft_model_id = "somosnlp-hackathon-2025/Llama-2-7b-hf-lora-refranes"

    # With torch_dtype="auto" the 7B base model loads in half precision,
    # which needs roughly 14 GB of accelerator memory plus the adapter.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype="auto",
        device_map="auto",
        token=HF_TOKEN
    )

    # Attach the LoRA adapter on top of the frozen base weights.
    model = PeftModel.from_pretrained(base_model, peft_model_id)

    tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=HF_TOKEN)
    # Llama-2 ships without a pad token; reuse EOS so generation can pad.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

# Loaded lazily on the first request so the Space starts without blocking.
model = None
tokenizer = None

def generate_response(input_text, max_tokens, temperature, top_p, repetition_penalty):
    global model, tokenizer

    if model is None or tokenizer is None:
        model, tokenizer = load_model()

    # Append the separator so the model continues with the explanation.
    inputs = tokenizer(input_text + SPECIAL_TOKEN, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=True,
            top_p=top_p,
            top_k=TOP_K,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.pad_token_id
        )

    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # The decoded output echoes the prompt; keep only the text after the
    # separator.
    if SPECIAL_TOKEN in full_response:
        return full_response.split(SPECIAL_TOKEN, 1)[1].strip()

    return full_response.strip()

def chat_interface(message, history, system_message, max_tokens, temperature, top_p, repetition_penalty):
    # `history` is part of the gr.ChatInterface signature but is unused:
    # each proverb is explained independently of the conversation.
    prompt = f"{system_message}\n{message}" if system_message else message

    return generate_response(
        prompt,
        max_tokens,
        temperature,
        top_p,
        repetition_penalty
    )
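# With the default system message, the model therefore sees a prompt of the
# form "<system message>\n<refrán>->:" before it starts generating.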

demo = gr.ChatInterface(
    chat_interface,
    title="Sabiduría Popular - Refranes",
    description="Esta aplicación explica el significado de refranes en español utilizando un modelo de lenguaje. Escribe un refrán y el modelo te explicará su significado.",
    examples=[
        ["A caballo regalado no le mires el diente"],
        ["Más vale pájaro en mano que ciento volando"],
        ["Quien a buen árbol se arrima, buena sombra le cobija"],
        ["No por mucho madrugar amanece más temprano"]
    ],
    additional_inputs=[
        gr.Textbox(
            value="Eres un experto en sabiduría popular española. Tu tarea es explicar el significado de refranes en español de manera clara y concisa.",
            label="System message"
        ),
        gr.Slider(
            minimum=1,
            maximum=500,
            value=MAX_NEW_TOKENS,
            step=1,
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=TEMPERATURE,
            step=0.1,
            label="Temperature"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=TOP_P,
            step=0.05,
            label="Top-p (nucleus sampling)"
        ),
        gr.Slider(
            minimum=1.0,
            maximum=2.0,
            value=REPETITION_PENALTY,
            step=0.05,
            label="Repetition penalty"
        ),
    ],
    theme="soft"
)

if __name__ == "__main__":
    print("Iniciando la aplicación. El modelo se cargará con la primera consulta.")
    demo.launch()
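
# A minimal sketch of querying the running app programmatically with
# gradio_client. This is hypothetical usage, not part of the app itself:
# the endpoint name ("/chat") and which arguments must be passed explicitly
# depend on the installed Gradio version.
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   result = client.predict(
#       "A quien madruga, Dios le ayuda",  # message
#       api_name="/chat",
#   )
#   print(result)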