# [page header residue] tecuhtli's picture
# Actualización del modelo / README / pesos / etc.
# ef4a4c8
#***************************************************************************
# Mori (tech-only) — Streamlit App sin sidebar ni social, con RAG opcional
#***************************************************************************
import os, sys, warnings, json, joblib, random, re, unicodedata, uuid, torch, csv
import numpy as np
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
import streamlit as st
import datetime as dt
from pathlib import Path
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer # RAG embeddings
# =========================
# General configuration
# =========================
# Private Hugging Face token; set it as a Space secret or environment variable.
# None when the variable is not defined.
HF_TOKEN = os.getenv("HF_TOKEN")
#***************************************************************************
# Sidebar controls for generation params
#***************************************************************************
def _clamp(value, lo, hi):
    """Return ``value`` limited to the closed interval ``[lo, hi]``."""
    return max(lo, min(hi, value))


def sidebar_params():
    """Render the sidebar controls and return the text-generation parameters.

    On the first run the generation settings are seeded in
    ``st.session_state`` (``setdefault`` never overwrites later choices).
    The persona preset buttons overwrite several settings and rerun the app.
    The advanced panel, hidden initially, exposes one slider per setting.

    Returns:
        dict with ``persona``, ``mode``, ``max_new_tokens``, ``min_tokens``,
        ``no_repeat_ngram_size`` and ``repetition_penalty``. Beam mode adds
        ``num_beams`` and ``length_penalty``; sampling mode adds
        ``temperature`` and ``top_p``.
    """
    gen_docs = "https://huggingface.co/docs/transformers/main_classes/text_generation"
    ss = st.session_state
    with st.sidebar:
        st.title("🎮 Adjustments (T5-Base)")

        # First-run defaults. The advanced panel starts hidden.
        ss.setdefault("show_llm_controls", False)
        ss.setdefault("persona", "Normal")
        ss.setdefault("mode", "beam")  # 'beam' | 'sampling'
        ss.setdefault("max_new", 128)
        ss.setdefault("min_tok", 16)
        ss.setdefault("no_repeat", 3)
        ss.setdefault("num_beams", 4)
        ss.setdefault("length_penalty", 1.0)
        ss.setdefault("temperature", 0.7)
        ss.setdefault("top_p", 0.9)
        ss.setdefault("repetition_penalty", 1.0)

        # ----------------------------
        # Personality presets
        # ----------------------------
        st.header("💡 Predefined Personalities")
        c1, c2 = st.columns(2)
        with c1:
            if st.button("Normal 🧐", use_container_width=True):
                ss.update({
                    "persona": "Normal",
                    "mode": "beam",
                    "num_beams": 1,
                    "max_new": 92,
                    "min_tok": 32,
                    "no_repeat": 3,
                    "length_penalty": .3,
                    "temperature": 0.4,
                    "top_p": 0.9,
                    "repetition_penalty": .4,
                })
                st.rerun()
        with c2:
            if st.button("Enthusiastic 😃", use_container_width=True):
                ss.update({
                    "persona": "Enthusiastic",
                    "mode": "sampling",
                    "max_new": 192,
                    "min_tok": 48,
                    "no_repeat": 3,
                    "temperature": .8,
                    "top_p": 0.95,
                    "repetition_penalty": 1.0,
                })
                st.rerun()
        st.caption(f"Selected Personality: **{ss.persona}**")

        # ----------------------------
        # Show/hide the advanced settings
        # ----------------------------
        if st.button(("🔼 Hide" if ss.show_llm_controls else "🔽 Show") + " Advanced Settings"):
            ss.show_llm_controls = not ss.show_llm_controls
            st.rerun()

        # ----------------------------
        # Model controls (strategy, sliders, prompt preview)
        # ----------------------------
        if ss.show_llm_controls:
            st.header("⚙️ Manual Adjustments")
            st.subheader("📝 Text Generation")
            picked = st.radio(
                "Strategy",
                ["Beam search (stable)", "Sampling (creative)"],
                index=0 if ss.mode == "beam" else 1,
                help="https://huggingface.co/docs/transformers/generation_strategies",
            )
            ss.mode = "beam" if picked.startswith("Beam") else "sampling"

            st.subheader("🔧 LLM text generation parameters")
            # Every slider's initial value is clamped into its [min, max] range.
            # st.slider raises when the value is out of range, and stored values
            # can be: the "Normal" preset sets num_beams=1 while the slider
            # starts at 2, and lowering max_new_tokens can leave min_tok above
            # the new upper bound.
            ss.max_new = st.slider(
                "max_new_tokens", 16, 256, _clamp(int(ss.max_new), 16, 256),
                step=8, help=gen_docs,
            )
            ss.min_tok = st.slider(
                "min_tokens", 0, int(ss.max_new),
                _clamp(int(ss.min_tok), 0, int(ss.max_new)),
                help=gen_docs,
            )
            ss.no_repeat = st.slider(
                "no_repeat_ngram_size", 0, 6, _clamp(int(ss.no_repeat), 0, 6),
                help=gen_docs,
            )

            # Strategy-specific controls
            if ss.mode == "beam":
                ss.num_beams = st.slider(
                    "num_beams", 2, 8, _clamp(int(ss.num_beams), 2, 8),
                    help=gen_docs,
                )
                ss.length_penalty = st.slider(
                    "length_penalty", 0.0, 2.0, _clamp(float(ss.length_penalty), 0.0, 2.0),
                    step=0.1, help=gen_docs,
                )
            else:
                ss.temperature = st.slider(
                    "temperature", 0.1, 1.5, _clamp(float(ss.temperature), 0.1, 1.5),
                    step=0.05, help=gen_docs,
                )
                ss.top_p = st.slider(
                    "top_p", 0.5, 1.0, _clamp(float(ss.top_p), 0.5, 1.0),
                    step=0.01, help=gen_docs,
                )

            # NOTE(review): placed inside the advanced panel; the original
            # indentation was lost — confirm this matches the intended layout.
            if "last_prompt" in st.session_state and st.session_state["last_prompt"]:
                with st.expander("Show generated prompt"):
                    st.text_area(
                        "Prompt actual:",
                        st.session_state["last_prompt"],
                        height=200,
                        disabled=True,
                    )
            else:
                st.caption("👉 No prompt is available yet.")

    # ----------------------------
    # Build the parameter dictionary
    # ----------------------------
    params = {
        "persona": ss.persona,
        "mode": ss.mode,
        "max_new_tokens": int(ss.max_new),
        "min_tokens": int(ss.min_tok),
        "no_repeat_ngram_size": int(ss.no_repeat),
        "repetition_penalty": float(ss.repetition_penalty),
    }
    if ss.mode == "beam":
        params.update({
            "num_beams": int(ss.num_beams),
            "length_penalty": float(ss.length_penalty),
        })
    else:
        params.update({
            "temperature": float(ss.temperature),
            "top_p": float(ss.top_p),
        })
    return params
#***************************************************************************
# Functions
#***************************************************************************
# A sentence boundary is whitespace that follows ., !, ? or an ellipsis.
_SENTENCE_BOUNDARY_RE = re.compile(r'(?<=[\.\!\?…])\s+')


def truncate_sentences(text: str, max_sentences: int = 4) -> str:
    """Keep only the first ``max_sentences`` sentences of ``text``.

    The text is stripped first; kept sentences are re-joined with single
    spaces, and a final "." is appended when the result does not already
    end in ., !, ? or "…". Empty (or whitespace-only) input yields "".
    """
    stripped = text.strip()
    if not stripped:
        return stripped
    sentences = _SENTENCE_BOUNDARY_RE.split(stripped)
    kept = " ".join(sentences[:max_sentences]).strip()
    needs_period = bool(kept) and kept[-1] not in ".!?…"
    return kept + "." if needs_period else kept
def _load_json_safe(path: Path, fallback: dict) -> dict:
try:
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception:
return fallback
# Clears the question text box
def limpiar_input():
    """Reset the question text area by writing "" to its session key "entrada".

    NOTE(review): Streamlit refuses writes to a widget's key after that widget
    has been created in the same run, so this presumably only works as an
    ``on_click``/``on_change`` callback — confirm before calling it elsewhere.
    """
    state = st.session_state
    state["entrada"] = ""
# Path of a model folder inside the local "Models" directory
def get_model_path(folder_name):
    """Return ``Models/<folder_name>`` as a ``Path``.

    NOTE: the path is relative, so it resolves against the current working
    directory, not against this script's location.
    """
    return Path("Models", folder_name)
# Persists one user interaction to disk
def saving_interaction(question, response, context, user_id):
    """Append one question/answer exchange to the logs in ./Statistics.

    The directory is created if needed, then two files are appended to:
      * conversaciones_log.csv: a header row when the file is first created,
        then one row per call (timestamp, user_id, contexto, pregunta,
        respuesta).
      * conversaciones_log.jsonl: one JSON object per line with keys
        timestamp, user_id, context, pregunta and respuesta. Non-ASCII text is
        written as-is (``ensure_ascii=False``).

    Args:
        question: user input question.
        response: assistant response to the question.
        context: label found by the context classifier.
        user_id: identifier of the current session's user.
    """
    stamp = dt.datetime.now().isoformat()
    log_dir = Path("Statistics")
    log_dir.mkdir(parents=True, exist_ok=True)

    csv_path = log_dir / "conversaciones_log.csv"
    is_new_csv = not csv_path.exists()
    with csv_path.open(mode="a", encoding="utf-8", newline="") as csv_file:
        rows = [[stamp, user_id, context, question, response]]
        if is_new_csv:
            rows.insert(0, ["timestamp", "user_id", "contexto", "pregunta", "respuesta"])
        csv.writer(csv_file).writerows(rows)

    jsonl_path = log_dir / "conversaciones_log.jsonl"
    with jsonl_path.open(mode="a", encoding="utf-8") as jsonl_file:
        record = dict(
            timestamp=stamp,
            user_id=user_id,
            context=context,
            pregunta=question,
            respuesta=response,
        )
        jsonl_file.write(json.dumps(record, ensure_ascii=False) + "\n")
# Loads a seq2seq model + tokenizer from a local folder (no network access)
@st.cache_resource
def load_model(path_str):
    """Load a seq2seq model and its tokenizer from a local directory.

    ``local_files_only=True`` means nothing is downloaded. The result is
    cached by ``st.cache_resource``, keyed on ``path_str``.

    Returns:
        tuple: ``(model, tokenizer)`` (note: model first).
    """
    model_dir = Path(path_str).resolve()
    offline = {"local_files_only": True}
    tokenizer = AutoTokenizer.from_pretrained(model_dir, **offline)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_dir, **offline)
    return model, tokenizer
#-------------------------------------------------------------------------
# Function to correct Spanish sentences' punctuation and missing characters
#-------------------------------------------------------------------------
# Removes "[Assistant Social]" / "(Assistant Técnico)"-style tags left by training data.
_ASSISTANT_TAG_RE = re.compile(
    r'\s*[\[\(]\s*Assistant\s+(?:Social|T[eé]nico|T[eé]cnico)\s*[\]\)]\s*', flags=re.I
)

# Ordered (pattern, replacement) rules that restore accents, "ñ" and common
# words the generator drops or misspells. Order matters: later rules see the
# output of earlier ones. Compiled once at import time.
_POLISH_FIXES = [(re.compile(pat), rep) for pat, rep in [
    (r'(?i)(^|\W)T\s+puedes(?P<p>[^\w]|$)', r'\1Tú puedes\g<p>'),
    (r'(?i)(^|\W)T\s+(ya|eres|estas|estás|tienes|puedes)\b', r'\1Tú \2'),
    (r'(?i)\bclaro que s(?:i|í)?\b(?P<p>[,.\!?…])?', r'Claro que sí\g<p>'),
    (r'(?i)(^|\s)si,', r'\1Sí,'),
    (r'(?i)(\beso\s+)s(\s+est[áa]\b)', r'\1sí\2'),
    (r'(?i)(^|[\s,;:])s(\s+es\b)', r'\1sí\2'),
    (r'(?i)\btiles\b', 'útiles'),
    (r'(?i)\butiles\b', 'útiles'),
    (r'(?i)\butil\b', 'útil'),
    (r'(?i)\baqui\b', 'aquí'),
    (r'(?i)\baqu\b(?=\s+estoy\b)', 'aquí'),
    (r'(?i)\balgn\b', 'algún'),
    (r'(?i)\balgun\b', 'algún'),
    (r'(?i)\bAnimo\b', 'Ánimo'),
    (r'(?i)\bcario\b', 'cariño'),
    (r'(?i)\baprendisaje\b', 'aprendizaje'),
    (r'(?i)\bmanana\b', 'mañana'),
    (r'(?i)\bmaana\b', 'mañana'),
    (r'(?i)\benergia\b', 'energía'),
    (r'(?i)\benerga\b', 'energía'),
    (r'(?i)\bextrano\b', 'extraño'),
    (r'(?i)\bextrana\b', 'extraña'),
    (r'(?i)\bextranar\b', 'extrañar'),
    (r'(?i)\bextranarte\b', 'extrañarte'),
    (r'(?i)\bextranas\b', 'extrañas'),
    (r'(?i)\bextranos\b', 'extraños'),
    (r'(?i)\baqu\b', 'aquí'),
    (r'(?i)\bestare\b', 'estaré'),
    (r'(?i)\bclarn\b', 'clarín'),
    (r'(?i)\bclarin\b', 'clarín'),
    (r'(?i)\bclar[íi]n\s+cornetas\b', 'clarín cornetas'),
    (r'(?i)(^|\s)s([,.;:!?])', r'\1Sí\2'),
    (r'(?i)\bfutbol\b', 'fútbol'),
    (r'(?i)(^|\s)as(\s+se\b)', r'\1Así\2'),
    (r'(?i)(^|\s)s(\s+orientarte\b)', r'\1sí\2'),
    (r'(?i)\bbuen dia\b', 'buen día'),
    (r'(?i)\bgran dia\b', 'gran día'),
    (r'(?i)\bdias\b', 'días'),
    (r'(?i)\bdia\b', 'día'),
    (r'(?i)\bgran da\b', 'gran día'),
    (r'(?i)\bacompa?a(r|rte|do|da|dos|das)?\b', r'acompaña\1'),
    (r'(?i)(^|\s)as([,.;:!?]|\s|$)', r'\1así\2'),
    (r'(?i)(^|\s)S lo se\b', r'\1Sí lo sé'),
    (r'(?i)(^|\s)S lo sé\b', r'\1Sí lo sé'),
    (r'(?i)\bcudese\b', 'cuídese'),
    (r'(?i)\bpequeo\b', 'pequeño'),
    (r'(?i)\bpequea\b', 'pequeña'),
    (r'(?i)\bpequeos\b', 'pequeños'),
    (r'(?i)\bpequeas\b', 'pequeñas'),
    (r'(?i)\bunico\b', 'único'),
    (r'(?i)\bunica\b', 'única'),
    (r'(?i)\bunicos\b', 'únicos'),
    (r'(?i)\bunicas\b', 'únicas'),
    (r'(?i)\bnico\b', 'único'),
    (r'(?i)\bnica\b', 'única'),
    (r'(?i)\bnicos\b', 'únicos'),
    (r'(?i)\bnicas\b', 'únicas'),
    (r'(?i)\bestadstico\b', 'estadístico'),
    (r'(?i)\bestadstica\b', 'estadística'),
    (r'(?i)\bestadsticos\b', 'estadísticos'),
    (r'(?i)\bestadsticas\b', 'estadísticas'),
    (r'(?i)\bcudate\b', 'cuídate'),
    (r'(?i)\bcuidate\b', 'cuídate'),
    (r'(?i)\bcuidese\b', 'cuídese'),
    (r'(?i)\bcuidense\b', 'cuídense'),
    (r'(?i)\bcudense\b', 'cuídense'),
    (r'(?i)\bgracias por confiar en m\b', 'gracias por confiar en mí'),
    (r'(?i)\bcada dia\b', 'cada día'),
    (r'(?i)\bcada da\b', 'cada día'),
    (r'(?i)\bsegun\b', 'según'),
    (r'(?i)\bcaracteristica(s)?\b', r'característica\1'),
    (r'(?i)\bcaracterstica(s)?\b', r'característica\1'),
    (r'(?i)\b([a-záéíóúñ]+)cion\b', r'\1ción'),
    (r'(?i)\bdeterminacio\b', 'determinación'),
]]

_ESO_ES_TODO_RE = re.compile(r'(?i)^eso es todo!(?P<r>(\s|$).*)')

# A question sentence: it starts at the beginning of the text or after
# whitespace that follows ., !, ? or "…", and ends at the next "?". The body
# may not cross a sentence terminator followed by whitespace. Otherwise the
# "^" start would swallow earlier sentences ("Hola! Qué tal?" would become
# "¿Hola! Qué tal?"). Terminators not followed by whitespace (e.g. "v2.0")
# stay allowed inside the body.
_OPEN_QUESTION_RE = re.compile(
    r'(?P<pre>^\s*|(?<=[.!?…])\s+)(?P<qbody>(?:[^?.!…]|[.!…](?!\s))*\?)'
)

_OPEN_EXCLAM_RE = re.compile(
    r'(?i)^(?P<w>(hola|gracias|genial|perfecto|claro|por supuesto|con gusto|listo|vaya|wow|tu puedes|tú puedes|clarín|clarin|clarín cornetas))!(?P<r>(\s|$).*)'
)


def polish_spanish(s: str) -> str:
    """Clean up generated Spanish text.

    Steps, in order:
      1. NFC-normalize and strip the text; remove "[Assistant ...]" tags.
      2. Apply the ``_POLISH_FIXES`` rules (accents, "ñ", common misspellings).
      3. Turn a leading "eso es todo!" into "¡Eso es todo!".
      4. Add the opening "¿" to each question sentence that lacks one.
      5. Add the opening "¡" to a leading stock interjection such as "Hola!".
      6. Collapse whitespace and end with "." when there is no final
         punctuation.

    Args:
        s: raw model output.

    Returns:
        The polished text; "" for empty or whitespace-only input.
    """
    s = unicodedata.normalize("NFC", s).strip()
    s = _ASSISTANT_TAG_RE.sub('', s)
    for pattern, replacement in _POLISH_FIXES:
        s = pattern.sub(replacement, s)
    s = _ESO_ES_TODO_RE.sub(r'¡Eso es todo!\g<r>', s)

    def add_opening_q(m):
        cuerpo = m.group('qbody')
        if '¿' in cuerpo:  # already opened; leave the sentence untouched
            return m.group(0)
        return f"{m.group('pre')}¿{cuerpo}"

    s = _OPEN_QUESTION_RE.sub(add_opening_q, s)

    def _open_exclam(m):
        palabra = m.group('w')
        resto = m.group('r') or ''
        return f'¡{palabra}!{resto}'

    s = _OPEN_EXCLAM_RE.sub(_open_exclam, s)
    s = re.sub(r'\s+', ' ', s).strip()
    if s and s[-1] not in ".!?…":
        s += "."
    return s
#-------------------------------------------------------------------------
# Function to remove repeated input in the Model answer
#-------------------------------------------------------------------------
def anti_echo(response: str, user_text: str) -> str:
    """Strip a leading echo of the user's question from a model response.

    Both strings are compared through ``normalize_for_route``, so accents,
    punctuation and case are ignored. An echo is detected when the normalized
    response starts with the normalized question; the question must be at
    least 4 characters and the match must end on a word boundary.

    The cut is made in the raw response, right after the shortest raw prefix
    that normalizes to the question. It also absorbs any trailing
    punctuation/whitespace that normalizes away. Leftover leading separators
    ("," ";" ":" "-" "–" "—") are then removed.

    Args:
        response: text produced by the model.
        user_text: the user's original question.

    Returns:
        The response without the echoed prefix. The original ``response`` is
        returned when no echo is found or when removing it would leave
        nothing.
    """
    un = normalize_for_route(user_text)
    rn = normalize_for_route(response)

    def _clean_leading(s: str) -> str:
        s = re.sub(r'^\s*[,;:\-–—]\s*', '', s)
        return re.sub(r'^\s+', '', s)

    # Only treat it as an echo when the match ends on a word boundary, so that
    # "que es python" does not eat the start of "que es pythonista".
    boundary = rn[len(un):len(un) + 1]
    if len(un) < 4 or not rn.startswith(un) or boundary not in ("", " "):
        return response

    # Map the normalized match back onto raw characters. Slicing by
    # len(user_text) would be wrong whenever normalization changes length
    # ("¿", "?", accents written differently, extra spaces).
    cut = None
    for end in range(1, len(response) + 1):
        norm_prefix = normalize_for_route(response[:end])
        if norm_prefix == un:
            cut = end  # keep extending across text that normalizes away
        elif cut is not None or len(norm_prefix) > len(un):
            break
    if cut is None:
        return response

    rest = _clean_leading(response[cut:])
    return rest if rest else response

#-------------------------------------------------------------------------
# Normalization helpers
#-------------------------------------------------------------------------
def normalize_for_route(s: str) -> str:
    """Return an accent-free, lowercase, punctuation-free form of ``s``.

    NFKD decomposition with combining marks dropped removes accents ("é" ->
    "e", "ñ" -> "n"). Every character that is not a word character,
    whitespace or "-" becomes a space. Whitespace is then collapsed, and the
    result is stripped and lowercased.
    """
    s = unicodedata.normalize("NFKD", s)
    s = "".join(ch for ch in s if not unicodedata.combining(ch))
    s = re.sub(r"[^\w\s-]", " ", s, flags=re.UNICODE)
    s = re.sub(r"\s+", " ", s).strip().lower()
    return s
# Leading words of a Spanish wh-question. These helpers expect text that has
# already gone through normalize_for_route (lowercase, accent-free).
_Q_STARTERS = {
    "como", "que", "quien", "quienes", "cuando", "donde",
    "por que", "para que",
    "cual", "cuales", "cuanto", "cuantos", "cuanta", "cuantas",
}
# Substrings that make a short message read as an exclamation.
# NOTE(review): accented entries ("clarín") cannot occur in normalized text.
_EXC_TRIGGERS = {
    "motiva", "motivame", "animate", "animame", "animo",
    "ayudame", "ayudame porfa",
    "clarin", "clarín", "clarinete", "clarin cornetas",
}
# Expressions that must be left without any ¿?¡! punctuation.
SPECIAL_NOPUNCT = {"kiubo", "quiubo", "que chido", "qué chido", "que buena onda"}
# Verbs that, as the first word, turn a short message into a yes/no question.
_Q_VERB_STARTERS = {
    "eres", "estas", "estás", "puedes", "sabes", "tienes", "quieres", "conoces",
    "crees", "piensas", "dirias", "dirías", "podrias", "podrías", "podras", "podrás",
}

#-------------------------------------------------------------------------
# Punctuation helpers
#-------------------------------------------------------------------------
def needs_question_marks(norm: str) -> bool:
    """True when ``norm`` has no "?" and starts with a wh-question word."""
    if "?" in norm:
        return False
    return any(norm == word or norm.startswith(f"{word} ") for word in _Q_STARTERS)


def needs_exclam(norm: str) -> bool:
    """True when ``norm`` has no "!" and contains an exclamation trigger."""
    if "!" in norm:
        return False
    return any(trigger in norm for trigger in _EXC_TRIGGERS)
#-------------------------------------------------------------------------
# Greetings detection
#-------------------------------------------------------------------------
# Exact Mexican-slang greetings (normalized text).
_SLANG_GREETINGS = {
    "que pex", "que onda", "ke pex", "k pex", "q onda",
    "kiubo", "quiubo", "quiubole", "quiúbole", "kionda", "k onda",
    "que rollo", "ke onda", "que show", "que tranza",
}
# "que/ke/k/q" followed by a slang noun: "que onda ...", "k pex ...".
_SLANG_PREFIX_RE = re.compile(r"^(q|k|ke|que)\s+(pex|onda|rollo|show|tranza)\b")
# "kiubo"/"quiubo"/"quiubole" at the start.
_SLANG_KIUBO_RE = re.compile(r"^(kiubo|quiubo|quiubole|quiúbole|quiubol[e]?)\b")


def is_slang_greeting(norm: str) -> bool:
    """True when normalized text ``norm`` is (or starts with) a slang greeting."""
    return (
        norm in _SLANG_GREETINGS
        or _SLANG_PREFIX_RE.match(norm) is not None
        or _SLANG_KIUBO_RE.match(norm) is not None
    )
#-------------------------------------------------------------------------
# Capitalization & autopunct
#-------------------------------------------------------------------------
def capitalize_spanish(s: str) -> str:
    """Strip ``s`` and upper-case its first alphabetic character.

    Leading non-letters such as "¿" or "¡" are skipped, so "¿qué?" becomes
    "¿Qué?". Text without letters is only stripped.
    """
    text = s.strip()
    for idx, ch in enumerate(text):
        if ch.isalpha():
            return text[:idx] + ch.upper() + text[idx + 1:]
    return text
def smart_autopunct(user_text: str) -> str:
    """Add Spanish opening/closing punctuation to a short user message.

    Texts longer than 20 characters are only capitalized. Shorter ones are
    checked in priority order:
      1. expressions in SPECIAL_NOPUNCT lose all ¿?¡!;
      2. "y si ..." is wrapped in ¿?;
      3. a lone "?" gets its opening "¿"; a lone "!" gets its opening "¡";
      4. slang greetings are wrapped in ¡!;
      5. wh-questions, verb-initial questions and "me ayudas / me puedes /
         podrías ..." are wrapped in ¿?;
      6. exclamation triggers are wrapped in ¡!.
    The result is always passed through ``capitalize_spanish``.
    """
    text = user_text.strip()
    if len(text) > 20:
        return capitalize_spanish(text)

    norm = normalize_for_route(text)
    if norm in SPECIAL_NOPUNCT:
        return capitalize_spanish(re.sub(r'[¿?!¡]+', '', text).strip())
    if norm.startswith("y si "):
        return capitalize_spanish(f"¿{text}?")
    if "?" in text and "¿" not in text:
        return capitalize_spanish("¿" + text)
    if "!" in text and "¡" not in text:
        return capitalize_spanish("¡" + text)
    if is_slang_greeting(norm):
        return capitalize_spanish(f"¡{text}!")

    tokens = norm.split()
    looks_like_question = (
        needs_question_marks(norm)
        or (bool(tokens) and tokens[0] in _Q_VERB_STARTERS)
        or re.match(r"^(me\s+ayudas?|me\s+puedes|podrias?|podras?)\b", norm) is not None
    )
    if looks_like_question:
        return capitalize_spanish(f"¿{text}?")
    if needs_exclam(norm):
        return capitalize_spanish(f"¡{text}!")
    return capitalize_spanish(text)
#-------------------------------------------------------------------------
# Seeds & helpers
#-------------------------------------------------------------------------
def set_seeds(seed: int = 42):
    """Seed Python, NumPy and PyTorch (plus all CUDA devices when present).

    Also makes cuDNN deterministic and disables its autotuner. This is a
    process-wide side effect.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
# --- Personalities (prompt style only; generation parameters come from the sidebar) ---
def persona_style_prompt(persona: str, domain: str) -> str:
    """Return the style instruction prepended to the prompt for ``persona``.

    Only "Enthusiastic" adds an instruction; every other persona (including
    "Normal") gets "". ``domain`` is accepted for API compatibility but is
    currently unused.
    """
    enthusiastic = "Responde de forma creativa, usa al menos 232 palabras. "
    return enthusiastic if persona == "Enthusiastic" else ""
#-------------------------------------------------------------------------
# Classifier
#-------------------------------------------------------------------------
def classify_context(question, label_classes, model, tokenizer, device):
    """Predict the context label of ``question`` with the sequence classifier.

    The model is moved to ``device`` and switched to eval mode, as the
    generator functions also do. Eval mode keeps dropout off so predictions
    are deterministic. Input is truncated to 128 tokens.

    Args:
        question: user text.
        label_classes: indexable sequence mapping class index -> label.
        model: Hugging Face sequence-classification model (output has
            ``.logits``).
        tokenizer: matching tokenizer.
        device: torch device.

    Returns:
        ``label_classes[argmax(logits)]`` for the single input.
    """
    model = model.to(device).eval()
    encoded = tokenizer(question, return_tensors="pt", padding=True, truncation=True, max_length=128)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits
    pred_intent = torch.argmax(logits, dim=1).item()
    return label_classes[pred_intent]
#-------------------------------------------------------------------------
# Chatbot response for technical contexts using a Hugging Face model
#-------------------------------------------------------------------------
def technical_asnwer(question, context, model, tokenizer, device, gen_params=None):
    """Generate an answer to a technical question with the seq2seq model.

    The prompt is ``"<persona style>Context: <context> [SEP] Question: <question>."``,
    truncated to 256 input tokens. The "[" token is banned from the output.

    Args:
        question: user text.
        context: label predicted by the classifier (inserted in the prompt).
        model, tokenizer: Hugging Face seq2seq model and its tokenizer.
        device: torch device the model is moved to.
        gen_params: dict as returned by ``sidebar_params()``. ``None`` or
            missing keys fall back to the sidebar's first-run defaults
            (128 / 16 / 3 / 1.0, beam search with 4 beams). Previously
            ``None`` crashed.

    Returns:
        str: generated text, cut to one sentence for the "Normal" persona,
        then passed through ``polish_spanish``.

    Side effects:
        Writes ``last_prompt``, ``just_generated`` and ``last_response`` (the
        polished text, consistent with ``social_asnwer``) in
        ``st.session_state``.
    """
    gp = gen_params or {}
    model = model.to(device).eval()
    persona_name = gp.get("persona", st.session_state.get("persona", "Normal"))
    style = persona_style_prompt(persona_name, "technical")

    # Prompt template the technical model expects.
    input_text = f"{style}Context: {context} [SEP] Question: {question}."
    st.session_state["last_prompt"] = input_text
    st.session_state["just_generated"] = True

    enc = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=256).to(device)
    bad_words = ["["]
    bad_ids = [tokenizer(bw, add_special_tokens=False).input_ids for bw in bad_words]

    # Generation settings, with clamps that keep them valid for generate().
    max_new = int(gp.get("max_new_tokens", 128))
    min_new = min(max(0, int(gp.get("min_tokens", 16))), max_new)
    common = dict(
        max_new_tokens=max_new,
        min_new_tokens=min_new,
        no_repeat_ngram_size=int(gp.get("no_repeat_ngram_size", 3)),
        repetition_penalty=max(1.0, float(gp.get("repetition_penalty", 1.0))),
        bad_words_ids=bad_ids,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    if gp.get("mode", "beam") == "sampling":
        strategy = dict(
            do_sample=True,
            num_beams=1,
            temperature=max(0.1, float(gp.get("temperature", 0.7))),
            top_p=min(1.0, max(0.5, float(gp.get("top_p", 0.9)))),
        )
    else:
        strategy = dict(
            do_sample=False,
            num_beams=max(2, int(gp.get("num_beams", 4))),
            length_penalty=float(gp.get("length_penalty", 1.0)),
        )

    out_ids = model.generate(
        input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], **strategy, **common
    )
    text = tokenizer.decode(out_ids[0], skip_special_tokens=True)
    if persona_name == "Normal":
        text = truncate_sentences(text, max_sentences=1)
    text = polish_spanish(text)
    st.session_state["last_response"] = text
    return text
#-------------------------------------------------------------------------
# Chatbot response for social contexts using a Hugging Face model
#-------------------------------------------------------------------------
def social_asnwer(question, model, tokenizer, device, gen_params=None, block_web=True):
    """Generate a small-talk reply with the social seq2seq model.

    The raw question is the prompt (zero-shot, no template), truncated to
    192 input tokens. "[", "Thanks" and "thank you" are banned from the
    output. With ``block_web`` the ban also covers "website", "http", "www"
    and ".com".

    Args:
        question: user text.
        model, tokenizer: Hugging Face seq2seq model and its tokenizer.
        device: torch device the model is moved to.
        gen_params: dict as returned by ``sidebar_params()``. ``None`` or
            missing keys fall back to the sidebar's first-run defaults
            (previously ``None`` crashed).
        block_web: whether to ban web-related words.

    Returns:
        str: generated text, cut to two sentences for the "Normal" persona,
        then polished and capitalized.

    Side effects:
        Writes ``last_prompt``, ``just_generated`` and ``last_response`` in
        ``st.session_state``.
    """
    gp = gen_params or {}
    model = model.to(device).eval()
    persona_name = gp.get("persona", st.session_state.get("persona", "Normal"))
    prompt = question
    st.session_state["last_prompt"] = prompt
    st.session_state["just_generated"] = True

    enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=192).to(device)
    bad_words = ["[", "Thanks", "thank you"]
    if block_web:
        bad_words += ["website", "http", "www", ".com"]
    bad_ids = [tokenizer(bw, add_special_tokens=False).input_ids for bw in bad_words]

    # Generation settings, with clamps that keep them valid for generate().
    max_new = int(gp.get("max_new_tokens", 128))
    min_new = min(max(0, int(gp.get("min_tokens", 16))), max_new)
    common = dict(
        max_new_tokens=max_new,
        min_new_tokens=min_new,
        no_repeat_ngram_size=int(gp.get("no_repeat_ngram_size", 3)),
        repetition_penalty=max(1.0, float(gp.get("repetition_penalty", 1.0))),
        bad_words_ids=bad_ids,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    if gp.get("mode", "beam") == "sampling":
        strategy = dict(
            do_sample=True,
            num_beams=1,
            temperature=max(0.1, float(gp.get("temperature", 0.7))),
            top_p=min(1.0, max(0.5, float(gp.get("top_p", 0.9)))),
        )
    else:
        strategy = dict(
            do_sample=False,
            num_beams=max(2, int(gp.get("num_beams", 4))),
            length_penalty=float(gp.get("length_penalty", 1.0)),
        )

    out_ids = model.generate(
        input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], **strategy, **common
    )
    text = tokenizer.decode(out_ids[0], skip_special_tokens=True)
    if persona_name == "Normal":
        text = truncate_sentences(text, max_sentences=2)
    text = polish_spanish(text)
    text = capitalize_spanish(text)
    st.session_state["last_response"] = text
    return text
#-------------------------------------------------------------------------
# Rule overrides
#-------------------------------------------------------------------------
def rule_intent_override(user_text: str, predicted_label: str) -> str:
    """Force the "social" route for a few exact motivational/slang messages.

    The whole normalized text must match one of the listed phrases. Anything
    else keeps ``predicted_label``.
    NOTE(review): "qué tranza" can never match (normalization strips accents)
    and "que tranza" is listed twice; both are harmless.
    """
    normalized = normalize_for_route(user_text)
    forced_social = re.fullmatch(
        r"(motivame|motiva|animame|animo|ayudame|que tranza|qué tranza|que tranza)", normalized
    ) is not None
    return "social" if forced_social else predicted_label
#-------------------------------------------------------------------------
# Router
#-------------------------------------------------------------------------
def contextual_asnwer(question, label_classes, context_model, cont_tok,
                      tec_model, tec_tok, soc_model, soc_tok, device, gen_params=None, block_web=True):
    """Classify ``question`` and answer it with the social or technical model.

    The classifier's label can be overridden by ``rule_intent_override``.
    "social" goes to ``social_asnwer``; every other label (e.g. "modelos",
    "estadística") goes to ``technical_asnwer`` together with the label. When
    ``gen_params`` contains "seed", all RNGs are seeded first.

    Returns:
        tuple: ``(answer_text, context_label)``.
    """
    context = classify_context(question, label_classes, context_model, cont_tok, device)
    context = rule_intent_override(question, context)

    if gen_params and "seed" in gen_params:
        set_seeds(gen_params["seed"])

    if context == "social":
        answer = social_asnwer(question, soc_model, soc_tok, device,
                               gen_params=gen_params, block_web=block_web)
    else:
        answer = technical_asnwer(question, context, tec_model, tec_tok, device,
                                  gen_params=gen_params)
    return answer, context
#***************************************************************************
# MAIN
#***************************************************************************
@st.cache_resource
def _load_assistant_models():
    """Load the label encoder classes and the three Hugging Face models once.

    Streamlit re-executes the whole script on every interaction. Without this
    cache, every click would rebuild the classifier and both T5 models.

    Returns:
        tuple: ``(label_classes, context_model, cont_tok, tec_model, tec_tok,
        soc_model, soc_tok)``.
    """
    # NOTE(review): joblib.load unpickles a file fetched from the Hub; this is
    # only acceptable because the repo is presumably the project's own,
    # trusted repository.
    labels_path = hf_hub_download(
        repo_id="tecuhtli/assistant-classifier-bert", filename="context_labels.pkl", token=HF_TOKEN
    )
    label_classes = joblib.load(labels_path)
    # Context classifier
    context_model = AutoModelForSequenceClassification.from_pretrained(
        "tecuhtli/assistant-classifier-bert", token=HF_TOKEN
    )
    cont_tok = AutoTokenizer.from_pretrained("tecuhtli/assistant-classifier-bert", token=HF_TOKEN)
    # Technical model
    tec_tok = AutoTokenizer.from_pretrained("tecuhtli/assistant-technical-t5", token=HF_TOKEN)
    tec_model = AutoModelForSeq2SeqLM.from_pretrained("tecuhtli/assistant-technical-t5", token=HF_TOKEN)
    # Social model
    soc_tok = AutoTokenizer.from_pretrained("tecuhtli/assistant-social-t5", token=HF_TOKEN)
    soc_model = AutoModelForSeq2SeqLM.from_pretrained("tecuhtli/assistant-social-t5", token=HF_TOKEN)
    return label_classes, context_model, cont_tok, tec_model, tec_tok, soc_model, soc_tok


def _render_chat_html(historial):
    """Build the scrollable chat transcript (newest first) as ONE HTML string.

    Each ``st.markdown`` call is rendered as a separate element, so a
    container ``<div>`` opened in one call cannot wrap bubbles emitted by
    later calls. Everything has to be sent together for the container to
    scroll. Author and message text are HTML-escaped because they come from
    the user and from the model and are rendered with ``unsafe_allow_html``.

    Args:
        historial: list of ``(author, text)`` or ``(author, text, time)``
            tuples.
    """
    import html

    bubble_common = (
        "padding: 10px 14px; border-radius: 12px; margin: 6px 0; "
        "display: inline-block; max-width: 80%; clear: both;"
    )
    parts = [
        '<div id="chat-container" style="max-height: 400px; overflow-y: auto; padding: 10px; '
        'border: 1px solid #333; border-radius: 10px; '
        'background: linear-gradient(180deg, #0e0e0e 0%, #1b1b1b 100%); margin-top: 10px;">'
    ]
    for msg in reversed(historial):
        autor, texto = msg[0], msg[1]
        safe_autor = html.escape(str(autor))
        # "\n" becomes <br>: a blank line would end the markdown HTML block early.
        safe_texto = html.escape(str(texto)).replace("\n", "<br>")
        if autor == "Tú":
            style = ("text-align: right; background-color: #2d2d2d; color: #e6e6e6; "
                     "border: 1px solid #3a3a3a; float: right; ")
            icon = "🧍‍♂️"
        else:
            style = ("text-align: left; background-color: #162b1f; color: #d9ead3; "
                     "border: 1px solid #264d36; float: left; ")
            icon = "🤖"
        parts.append(f'<div style="{style}{bubble_common}">{icon} <b>{safe_autor}:</b> {safe_texto}</div>')
    parts.append('<div style="clear: both;"></div></div>')
    return "".join(parts)


if __name__ == '__main__':
    # --- State that must persist across reruns ---
    ss = st.session_state
    ss.setdefault("historial", [])
    ss.setdefault("last_prompt", "")
    ss.setdefault("last_response", "")
    ss.setdefault("just_generated", False)

    # Sidebar (full control of generation parameters)
    GEN_PARAMS = sidebar_params()

    # One short random ID per browser session
    if "user_id" not in ss:
        ss["user_id"] = str(uuid.uuid4())[:8]

    # Models are loaded once per process (see _load_assistant_models)
    (label_classes, context_model, cont_tok,
     tec_model, tec_tok, soc_model, soc_tok) = _load_assistant_models()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Assistant presentation
    st.title("🤖 Your Personal Assistant 🎓")
    st.caption("🙋🏽‍ You can ask me about technical concepts such as visualization, data cleaning, BI, and more.")
    st.caption("🙇🏽 I can *only* understand and answer in Spanish (🦅🇲🇽🌵).")
    st.caption("➡️ At this stage, I can respond to simple questions such as:")
    st.caption(" • ¿Cómo estás? • ¿Qué es...? • Explícame algo • Define algo • ¿Para qué sirve...?")
    st.caption("😊 If you want to know me better, visit: [hazutecuhtli.github.io](https://github.com/hazutecuhtli/LLMs_FineTuned_Chatbot)")
    st.markdown("<br>", unsafe_allow_html=True)
    st.caption("✏️ Type **'salir'** to exit.")

    # Clear the text area BEFORE it is created; Streamlit forbids changing a
    # widget's key after the widget exists in the same run.
    if ss.pop("_clear_entrada", False):
        if "entrada" in ss:
            del ss["entrada"]

    # Answer saved by the previous run; shown below the form
    _flash = ss.pop("_flash_response", None)

    with st.form("formulario_assistant"):
        user_question = st.text_area("📝 Escribe tu pregunta aquí", key="entrada", height=100)
        submitted = st.form_submit_button("Responder")

    if submitted:
        if not (user_question or "").strip():
            st.info("Chatbot: ¿Podrías repetir eso? No entendí bien 😅")
        else:
            asked_at = dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            response, context = contextual_asnwer(
                user_question, label_classes, context_model, cont_tok,
                tec_model, tec_tok, soc_model, soc_tok, device,
                gen_params=GEN_PARAMS, block_web=True,
            )
            answered_at = dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            # Chat history
            ss.historial.append(("Tú", user_question, asked_at))
            ss.historial.append(("Assistant", response, answered_at))
            # Persist the interaction
            saving_interaction(user_question, response, context, ss["user_id"])
            # Show the answer after the rerun, and clear the text area then
            ss["_flash_response"] = response
            ss["_clear_entrada"] = True
            # Rerun so the sidebar shows the new prompt
            st.rerun()

    # Current answer (flash), right below the form
    if _flash:
        st.success(_flash)

    # Chat history: download button + scrollable chat bubbles
    if ss.historial:
        st.markdown("---")

        lineas = []
        for msg in reversed(ss.historial):
            if len(msg) == 3:
                autor, texto, hora = msg
                lineas.append(f"[{hora}] {autor}: {texto}")
            else:
                autor, texto = msg
                lineas.append(f"{autor}: {texto}")
        st.download_button(
            label="💾 Descargar conversación como .txt",
            data="\n\n".join(lineas),
            file_name="conversacion_assistant.txt",
            mime="text/plain",
            use_container_width=True,
        )

        st.markdown(_render_chat_html(ss.historial), unsafe_allow_html=True)
#***************************************************************************
# END
#***************************************************************************