DavidBazaldua committed on
Commit
d5b6e77
·
verified ·
1 Parent(s): b84dcec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -62
app.py CHANGED
@@ -1,70 +1,204 @@
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
-
5
- def respond(
6
- message,
7
- history: list[dict[str, str]],
8
- system_message,
9
- max_tokens,
10
- temperature,
11
- top_p,
12
- hf_token: gr.OAuthToken,
13
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 
 
 
16
  """
17
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
-
19
- messages = [{"role": "system", "content": system_message}]
20
-
21
- messages.extend(history)
22
-
23
- messages.append({"role": "user", "content": message})
24
-
25
- response = ""
26
-
27
- for message in client.chat_completion(
28
- messages,
29
- max_tokens=max_tokens,
30
- stream=True,
31
- temperature=temperature,
32
- top_p=top_p,
33
- ):
34
- choices = message.choices
35
- token = ""
36
- if len(choices) and choices[0].delta.content:
37
- token = choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- type="messages",
49
- additional_inputs=[
50
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
51
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
52
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
53
- gr.Slider(
54
- minimum=0.1,
55
- maximum=1.0,
56
- value=0.95,
57
- step=0.05,
58
- label="Top-p (nucleus sampling)",
59
- ),
60
- ],
61
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- with gr.Blocks() as demo:
64
- with gr.Sidebar():
65
- gr.LoginButton()
66
- chatbot.render()
 
 
 
 
 
 
67
 
 
 
 
 
 
68
 
 
69
  if __name__ == "__main__":
70
  demo.launch()
 
1
+ import os
2
+ import torch
3
  import gradio as gr
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+
6
+ # 1. Configuración básica
7
# 1. Basic configuration: which checkpoint to serve and on what device/precision.
MODEL_ID = "DavidBazaldua/llama-iris-finetuned"
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.float16
else:
    DEVICE, DTYPE = "cpu", torch.float32

# 2. Load the tokenizer and the model once, at import time.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# On GPU the placement is delegated to `device_map="auto"`; on CPU no device
# map is passed and the model is moved explicitly afterwards.
_load_kwargs = {
    "torch_dtype": DTYPE,
    "device_map": "auto" if DEVICE == "cuda" else None,
}
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_load_kwargs)
if DEVICE == "cpu":
    model.to(DEVICE)

# 3. Default system prompt, used whenever the UI field is left blank
#    (feel free to customise it).
DEFAULT_SYSTEM_PROMPT = (
    "You are Iris, a helpful, kind, and concise AI assistant. "
    "You answer in the same language as the user and you explain things clearly. "
    "If the user is Miriam, you can hablarle en confianza como amiga :)"
)
28
+
29
+ # 4. Función para construir el prompt con contexto + historial
30
def build_prompt(system_prompt, context, history, user_message):
    """Assemble the role-tagged text prompt that is fed to the model.

    The prompt is a flat string made of ``<|system|>``, ``<|user|>`` and
    ``<|assistant|>`` sections, always ending with an open ``<|assistant|>``
    tag so the model continues with its reply.

    Args:
        system_prompt: System instructions; omitted from the prompt when empty
            or ``None``.
        context: Extra reference text pasted by the user (documents, notes,
            ...); omitted when empty or ``None``.
        history: Iterable of ``(user, assistant)`` pairs from earlier turns,
            or ``None``/empty for a fresh conversation. Either half of a pair
            may be ``None`` (e.g. a turn still waiting for its reply); such
            halves are skipped instead of being rendered as the text "None".
        user_message: The message the user just sent.

    Returns:
        The full prompt as a single string.
    """
    # NOTE(review): these tags are assumed to match the template the model was
    # fine-tuned with - confirm, or switch to `tokenizer.apply_chat_template`.
    prompt_parts = []

    # System instructions
    if system_prompt:
        prompt_parts.append(f"<|system|>\n{system_prompt}\n")

    # Optional extra context, presented as a second system section
    if context:
        prompt_parts.append(
            "<|system|>\nThe following is extra context that may be useful. "
            "Use it to answer the user if relevant:\n"
            f"{context}\n"
        )

    # Previous turns; `history or ()` tolerates a missing history.
    for old_user, old_assistant in history or ():
        if old_user is not None:
            prompt_parts.append(f"<|user|>\n{old_user}\n")
        if old_assistant is not None:
            prompt_parts.append(f"<|assistant|>\n{old_assistant}\n")

    # Current message, followed by the open assistant tag to be completed
    prompt_parts.append(f"<|user|>\n{user_message}\n")
    prompt_parts.append("<|assistant|>\n")

    return "".join(prompt_parts)
63
+
64
+ # 5. Función de generación
65
def generate_answer(system_prompt, context, message, history, max_tokens, temperature, top_p):
    """Generate the model's reply to ``message`` and append it to the history.

    Args:
        system_prompt: System instructions; ``DEFAULT_SYSTEM_PROMPT`` is used
            when this is ``None`` or blank.
        context: Optional extra context text forwarded to ``build_prompt``.
        message: The user's current message.
        history: Earlier turns as a list of ``[user, assistant]`` pairs
            (``None`` is treated as an empty history).
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        A tuple ``(answer, history)`` where ``answer`` is the cleaned reply
        text and ``history`` is a new list with ``[message, answer]`` appended
        (the list passed in is not mutated).
    """
    if system_prompt is None or system_prompt.strip() == "":
        system_prompt = DEFAULT_SYSTEM_PROMPT
    history = list(history) if history else []

    prompt = build_prompt(system_prompt, context, history, message)

    # Send the inputs to wherever the model actually lives; with
    # `device_map="auto"` that is not necessarily the plain DEVICE string.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False
    ).to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            pad_token_id=tokenizer.eos_token_id
        )

    # The output sequence starts with the prompt tokens. Decode only what was
    # generated after them, instead of splitting the full text on
    # "<|assistant|>", which breaks whenever that tag shows up in the user's
    # text/context or in the model's own continuation.
    new_tokens = output_tokens[0][prompt_length:]
    answer = tokenizer.decode(new_tokens, skip_special_tokens=False)

    # If the model keeps going and invents further turns, keep only the text
    # before the first role tag.
    for role_tag in ("<|user|>", "<|assistant|>", "<|system|>"):
        answer = answer.partition(role_tag)[0]

    # Simple clean-up: drop end-of-sequence markers and surrounding whitespace.
    eos_token = getattr(tokenizer, "eos_token", None)
    if eos_token:
        answer = answer.replace(eos_token, "")
    answer = answer.replace("</s>", "").strip()

    # Record the latest turn in a new history list.
    history = history + [[message, answer]]

    return answer, history
107
+
108
+ # 6. Función wrapper para Gradio (usa el historial del Chatbot)
109
def chat_fn(message, history, system_prompt, context, max_tokens, temperature, top_p):
    """Gradio-facing wrapper around ``generate_answer``.

    Takes its arguments in the order the UI supplies them (message and chat
    history first), substitutes an empty list for a missing history, and
    returns the ``(answer, history)`` pair produced by ``generate_answer``.
    """
    previous_turns = [] if history is None else history
    answer, updated_history = generate_answer(
        system_prompt,
        context,
        message,
        previous_turns,
        max_tokens,
        temperature,
        top_p,
    )
    return answer, updated_history
114
+
115
+ # 7. Construcción de la UI en Gradio
116
# Build the Gradio UI: chat on the left, generation settings on the right.
# NOTE(review): the callbacks below assume the Chatbot value is a list of
# [user, assistant] pairs, as the rest of this file does - confirm this matches
# the Chatbot format of the installed Gradio version.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ✨ Iris – Tu modelo finetuneado
        Chatea con tu modelo, agrega contexto y ajusta el comportamiento del sistema.
        """
    )

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Chat con Iris",
                height=450
            )
            msg = gr.Textbox(
                label="Mensaje",
                placeholder="Escribe aquí tu pregunta...",
            )
            send_btn = gr.Button("Enviar ✉️", variant="primary")

        with gr.Column(scale=2):
            system_prompt_box = gr.Textbox(
                label="System prompt",
                value=DEFAULT_SYSTEM_PROMPT,
                lines=6
            )
            context_box = gr.Textbox(
                label="Contexto adicional (opcional)",
                placeholder="Pega aquí notas, documentos o datos que quieras que Iris use como contexto.",
                lines=10
            )

            max_tokens_slider = gr.Slider(
                label="Máx. tokens de respuesta",
                minimum=64,
                maximum=2048,
                value=512,
                step=32
            )
            temperature_slider = gr.Slider(
                label="Temperature",
                minimum=0.1,
                maximum=1.5,
                value=0.7,
                step=0.1
            )
            top_p_slider = gr.Slider(
                label="Top-p",
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05
            )

            clear_btn = gr.Button("Limpiar historial 🧹")

    # Events
    def user_submit(user_message, chat_history):
        """Clear the textbox and show the user's message immediately.

        The message is appended as a pending turn ``[user_message, None]``;
        blank messages leave the chat unchanged.
        """
        if chat_history is None:
            chat_history = []
        if not user_message or not user_message.strip():
            return "", chat_history
        return "", chat_history + [[user_message, None]]

    def bot_reply(chat_history, system_prompt, context, max_tokens, temperature, top_p):
        """Fill in the reply for the pending turn created by ``user_submit``.

        The message is read from the pending turn rather than from the
        textbox, because the textbox has already been cleared at this point.
        Does nothing when there is no pending turn.
        """
        if not chat_history or chat_history[-1][1] is not None:
            return chat_history or []
        *earlier_turns, (pending_message, _) = chat_history
        _, updated_history = chat_fn(
            pending_message,
            earlier_turns,
            system_prompt,
            context,
            max_tokens,
            temperature,
            top_p,
        )
        return updated_history

    generation_inputs = [
        chatbot,
        system_prompt_box,
        context_box,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
    ]

    # The button and the Enter key share the same two-step flow: echo the
    # message right away, then generate. The Chatbot is written exactly once
    # per step (previously it was listed twice as an output and first received
    # the bare answer string).
    for trigger in (send_btn.click, msg.submit):
        trigger(
            fn=user_submit,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
            queue=False
        ).then(
            fn=bot_reply,
            inputs=generation_inputs,
            outputs=chatbot,
        )

    clear_btn.click(
        lambda: [],
        None,
        chatbot
    )
201
 
202
# 8. Start the app when executed as a script (HF Spaces runs `python app.py`).
if __name__ == "__main__":
    demo.launch()