|
|
|
|
|
|
|
|
|
|
|
import os, sys, warnings, json, joblib, random, re, unicodedata, uuid, torch, csv |
|
|
import numpy as np |
|
|
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" |
|
|
import streamlit as st |
|
|
import datetime as dt |
|
|
from pathlib import Path |
|
|
import torch |
|
|
import numpy as np |
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification |
|
|
from huggingface_hub import hf_hub_download |
|
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
HF_TOKEN = os.environ.get("HF_TOKEN") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sidebar_params():
    """Render the sidebar controls and return the generation parameters.

    Returns:
        dict: persona, mode and shared decoding limits, plus mode-specific
        entries — (num_beams, length_penalty) for beam search or
        (temperature, top_p) for sampling.
    """
    with st.sidebar:
        st.title("🎮 Adjustments (T5-Base)")

        ss = st.session_state

        # One-time defaults for every session-state-backed control.
        # (The original set show_llm_controls twice — an explicit False and a
        # later dead setdefault(True); a single setdefault(False) is the net
        # behavior.)
        ss.setdefault("show_llm_controls", False)
        ss.setdefault("persona", "Normal")
        ss.setdefault("mode", "beam")
        ss.setdefault("max_new", 128)
        ss.setdefault("min_tok", 16)
        ss.setdefault("no_repeat", 3)
        ss.setdefault("num_beams", 4)
        ss.setdefault("length_penalty", 1.0)
        ss.setdefault("temperature", 0.7)
        ss.setdefault("top_p", 0.9)
        ss.setdefault("repetition_penalty", 1.0)

        st.header("💡 Predefined Personalities")
        c1, c2 = st.columns(2)

        with c1:
            if st.button("Normal 🧐", use_container_width=True):
                ss.update({
                    "persona": "Normal",
                    "mode": "beam",
                    "num_beams": 1,
                    "max_new": 92,
                    "min_tok": 32,
                    "no_repeat": 3,
                    "length_penalty": 0.3,
                    "temperature": 0.4,
                    "top_p": 0.9,
                    "repetition_penalty": 0.4,
                })
                st.rerun()

        with c2:
            if st.button("Enthusiastic 😃", use_container_width=True):
                ss.update({
                    "persona": "Enthusiastic",
                    "mode": "sampling",
                    "max_new": 192,
                    "min_tok": 48,
                    "no_repeat": 3,
                    "temperature": 0.8,
                    "top_p": 0.95,
                    "repetition_penalty": 1.0,
                })
                st.rerun()

        st.caption(f"Selected Personality: **{ss.persona}**")

        # Toggle button for the advanced (manual) controls.
        if st.button(("🔼 Hide" if ss.show_llm_controls else "🔽 Show") + " Advanced Settings"):
            ss.show_llm_controls = not ss.show_llm_controls
            st.rerun()

        if ss.show_llm_controls:
            st.header("⚙️ Manual Adjustments")
            st.subheader("📝 Text Generation")
            picked = st.radio(
                "Strategy",
                ["Beam search (stable)", "Sampling (creative)"],
                index=0 if ss.mode == "beam" else 1,
                help="https://huggingface.co/docs/transformers/generation_strategies"
            )
            ss.mode = "beam" if picked.startswith("Beam") else "sampling"

            st.subheader("🔧 LLM text generation parameters")
            ss.max_new = st.slider(
                "max_new_tokens", 16, 256, int(ss.max_new), step=8,
                help="https://huggingface.co/docs/transformers/main_classes/text_generation"
            )
            # min_tokens is capped by the current max_new_tokens choice.
            ss.min_tok = st.slider(
                "min_tokens", 0, int(ss.max_new), int(ss.min_tok),
                help="https://huggingface.co/docs/transformers/main_classes/text_generation"
            )
            ss.no_repeat = st.slider(
                "no_repeat_ngram_size", 0, 6, int(ss.no_repeat),
                help="https://huggingface.co/docs/transformers/main_classes/text_generation"
            )

            # Mode-specific sliders.
            if ss.mode == "beam":
                ss.num_beams = st.slider(
                    "num_beams", 2, 8, int(ss.num_beams),
                    help="https://huggingface.co/docs/transformers/main_classes/text_generation"
                )
                ss.length_penalty = st.slider(
                    "length_penalty", 0.0, 2.0, float(ss.length_penalty),
                    step=0.1, help="https://huggingface.co/docs/transformers/main_classes/text_generation"
                )
            else:
                ss.temperature = st.slider(
                    "temperature", 0.1, 1.5, float(ss.temperature),
                    step=0.05, help="https://huggingface.co/docs/transformers/main_classes/text_generation"
                )
                ss.top_p = st.slider(
                    "top_p", 0.5, 1.0, float(ss.top_p),
                    step=0.01, help="https://huggingface.co/docs/transformers/main_classes/text_generation"
                )

        # Show the last prompt sent to the model, when one exists.
        if "last_prompt" in st.session_state and st.session_state["last_prompt"]:
            with st.expander("Show generated prompt"):
                st.text_area(
                    "Prompt actual:",
                    st.session_state["last_prompt"],
                    height=200,
                    disabled=True
                )
        else:
            st.caption("👉 No prompt is available yet.")

        # Assemble the parameter dict consumed by the answer functions.
        params = {
            "persona": ss.persona,
            "mode": ss.mode,
            "max_new_tokens": int(ss.max_new),
            "min_tokens": int(ss.min_tok),
            "no_repeat_ngram_size": int(ss.no_repeat),
            "repetition_penalty": float(ss.repetition_penalty),
        }
        if ss.mode == "beam":
            params.update({
                "num_beams": int(ss.num_beams),
                "length_penalty": float(ss.length_penalty),
            })
        else:
            params.update({
                "temperature": float(ss.temperature),
                "top_p": float(ss.top_p),
            })

        return params
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def truncate_sentences(text: str, max_sentences: int = 4) -> str:
    """Keep at most *max_sentences* sentences of *text*.

    Sentences are split on whitespace that follows ., !, ? or …; the result
    gains a trailing period when it does not already end in punctuation.
    """
    stripped = text.strip()
    if not stripped:
        return stripped
    sentence_break = re.compile(r'(?<=[\.\!\?…])\s+')
    pieces = sentence_break.split(stripped)
    result = " ".join(pieces[:max_sentences]).strip()
    if result and not result.endswith((".", "!", "?", "…")):
        result += "."
    return result
|
|
|
|
|
|
|
|
def _load_json_safe(path: Path, fallback: dict) -> dict: |
|
|
try: |
|
|
with open(path, "r", encoding="utf-8") as f: |
|
|
return json.load(f) |
|
|
except Exception: |
|
|
return fallback |
|
|
|
|
|
|
|
|
def limpiar_input():
    """Clear the question text box by blanking its session-state entry."""
    session = st.session_state
    session["entrada"] = ""
|
|
|
|
|
|
|
|
def get_model_path(folder_name):
    """Return the local path of *folder_name* under the Models directory."""
    models_root = Path("Models")
    return models_root / folder_name
|
|
|
|
|
|
|
|
def saving_interaction(question, response, context, user_id):
    """Append one chat interaction to the CSV and JSONL logs under Statistics/.

    Args:
        question: User input question.
        response: Assistant response to the user question.
        context: Context label for the input, found by the trained classifier.
        user_id: ID for the current user (unique per session).
    """
    stamp = dt.datetime.now().isoformat()
    stats_root = Path("Statistics")
    stats_root.mkdir(parents=True, exist_ok=True)

    # CSV log: write the header only when the file is new.
    csv_path = stats_root / "conversaciones_log.csv"
    csv_already_there = csv_path.exists()
    with open(csv_path, mode="a", encoding="utf-8", newline="") as csv_handle:
        writer = csv.writer(csv_handle)
        if not csv_already_there:
            writer.writerow(["timestamp", "user_id", "contexto", "pregunta", "respuesta"])
        writer.writerow([stamp, user_id, context, question, response])

    # JSONL log: one UTF-8 JSON object per line.
    jsonl_path = stats_root / "conversaciones_log.jsonl"
    record = {
        "timestamp": stamp,
        "user_id": user_id,
        "context": context,
        "pregunta": question,
        "respuesta": response}
    with open(jsonl_path, mode="a", encoding="utf-8") as jsonl_handle:
        jsonl_handle.write(json.dumps(record, ensure_ascii=False) + "\n")
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model(path_str):
    """Load a local seq2seq model and tokenizer from *path_str* (cached)."""
    local_dir = Path(path_str).resolve()
    tok = AutoTokenizer.from_pretrained(local_dir, local_files_only=True)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(local_dir, local_files_only=True)
    return mdl, tok
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def polish_spanish(s: str) -> str:
    """Clean up model output: restore Spanish accents, fix common garbled
    words, add opening ¿/¡ marks, collapse whitespace and ensure a final
    punctuation mark.

    The fixes table is applied in order; several entries overlap or repeat
    (e.g. aqui/cudese appear twice) — later duplicates are no-ops.
    NOTE(review): order matters here, so entries must not be re-sorted.
    """
    s = unicodedata.normalize("NFC", s).strip()
    # Strip leaked "[Assistant Social/Técnico]" style tags from the output.
    s = re.sub(r'\s*[\[\(]\s*Assistant\s+(?:Social|T[eé]nico|T[eé]cnico)\s*[\]\)]\s*', '', s, flags=re.I)
    # Ordered (pattern, replacement) pairs; mostly accent restoration for
    # words whose accented characters were dropped during generation.
    fixes = [
        (r'(?i)(^|\W)T\s+puedes(?P<p>[^\w]|$)', r'\1Tú puedes\g<p>'),
        (r'(?i)(^|\W)T\s+(ya|eres|estas|estás|tienes|puedes)\b', r'\1Tú \2'),
        (r'(?i)\bclaro que s(?:i|í)?\b(?P<p>[,.\!?…])?', r'Claro que sí\g<p>'),
        (r'(?i)(^|\s)si,', r'\1Sí,'),
        (r'(?i)(\beso\s+)s(\s+est[áa]\b)', r'\1sí\2'),
        (r'(?i)(^|[\s,;:])s(\s+es\b)', r'\1sí\2'),
        (r'(?i)\btiles\b', 'útiles'),
        (r'(?i)\butiles\b', 'útiles'),
        (r'(?i)\butil\b', 'útil'),
        (r'(?i)\baqui\b', 'aquí'),
        (r'(?i)\baqu\b(?=\s+estoy\b)', 'aquí'),
        (r'(?i)\balgn\b', 'algún'),
        (r'(?i)\balgun\b', 'algún'),
        (r'(?i)\bAnimo\b', 'Ánimo'),
        (r'(?i)\bcario\b', 'cariño'),
        (r'(?i)\baprendisaje\b', 'aprendizaje'),
        (r'(?i)\bmanana\b', 'mañana'),
        (r'(?i)\bmaana\b', 'mañana'),
        (r'(?i)\benergia\b', 'energía'),
        (r'(?i)\benerga\b', 'energía'),
        (r'(?i)\bextrano\b', 'extraño'),
        (r'(?i)\bextrana\b', 'extraña'),
        (r'(?i)\bextranar\b', 'extrañar'),
        (r'(?i)\bextranarte\b', 'extrañarte'),
        (r'(?i)\bextranas\b', 'extrañas'),
        (r'(?i)\bextranos\b', 'extraños'),
        # NOTE(review): duplicate of the earlier aqu/aqui entries (harmless).
        (r'(?i)\baqu\b', 'aquí'),
        (r'(?i)\baqui\b', 'aquí'),
        (r'(?i)\bestare\b', 'estaré'),
        (r'(?i)\bclarn\b', 'clarín'),
        (r'(?i)\bclarin\b', 'clarín'),
        (r'(?i)\bclar[íi]n\s+cornetas\b', 'clarín cornetas'),
        (r'(?i)(^|\s)s([,.;:!?])', r'\1Sí\2'),
        (r'(?i)\bfutbol\b', 'fútbol'),
        (r'(?i)(^|\s)as(\s+se\b)', r'\1Así\2'),
        (r'(?i)(^|\s)s(\s+orientarte\b)', r'\1sí\2'),
        (r'(?i)\bbuen dia\b', 'buen día'),
        (r'(?i)\bgran dia\b', 'gran día'),
        (r'(?i)\bdias\b', 'días'),
        (r'(?i)\bdia\b', 'día'),
        (r'(?i)\bgran da\b', 'gran día'),
        (r'(?i)\bacompa?a(r|rte|do|da|dos|das)?\b', r'acompaña\1'),
        (r'(?i)(^|\s)as([,.;:!?]|\s|$)', r'\1así\2'),
        (r'(?i)(^|\s)S lo se\b', r'\1Sí lo sé'),
        (r'(?i)(^|\s)S lo sé\b', r'\1Sí lo sé'),
        (r'(?i)\bcudese\b', 'cuídese'),
        (r'(?i)\bpequeo\b', 'pequeño'),
        (r'(?i)\bpequea\b', 'pequeña'),
        (r'(?i)\bpequeos\b', 'pequeños'),
        (r'(?i)\bpequeas\b', 'pequeñas'),
        (r'(?i)\bunico\b', 'único'),
        (r'(?i)\bunica\b', 'única'),
        (r'(?i)\bunicos\b', 'únicos'),
        (r'(?i)\bunicas\b', 'únicas'),
        (r'(?i)\bnico\b', 'único'),
        (r'(?i)\bnica\b', 'única'),
        (r'(?i)\bnicos\b', 'únicos'),
        (r'(?i)\bnicas\b', 'únicas'),
        (r'(?i)\bestadstico\b', 'estadístico'),
        (r'(?i)\bestadstica\b', 'estadística'),
        (r'(?i)\bestadsticos\b', 'estadísticos'),
        (r'(?i)\bestadsticas\b', 'estadísticas'),
        (r'(?i)\bcudate\b', 'cuídate'),
        (r'(?i)\bcuidate\b', 'cuídate'),
        (r'(?i)\bcuidese\b', 'cuídese'),
        # NOTE(review): duplicate cudese entry (harmless no-op).
        (r'(?i)\bcudese\b', 'cuídese'),
        (r'(?i)\bcuidense\b', 'cuídense'),
        (r'(?i)\bcudense\b', 'cuídense'),
        (r'(?i)\bgracias por confiar en m\b', 'gracias por confiar en mí'),
        (r'(?i)\bcada dia\b', 'cada día'),
        (r'(?i)\bcada da\b', 'cada día'),
        (r'(?i)\bsegun\b', 'según'),
        (r'(?i)\bcaracteristica(s)?\b', r'característica\1'),
        (r'(?i)\bcaracterstica(s)?\b', r'característica\1'),
        (r'(?i)\b([a-záéíóúñ]+)cion\b', r'\1ción'),
        (r'(?i)\bdeterminacio\b', 'determinación'),
    ]
    for pat, rep in fixes:
        s = re.sub(pat, rep, s)

    # Specific opener: "eso es todo!" -> "¡Eso es todo!".
    s = re.sub(r'(?i)^eso es todo!(?P<r>(\s|$).*)', r'¡Eso es todo!\g<r>', s)

    def add_opening_q(m):
        # Prepend '¿' to a question that already ends with '?' but lacks
        # the opening mark; skip if one is present.
        cuerpo = m.group('qbody')
        if '¿' in cuerpo:
            return m.group(0)
        return f"{m.group('pre')}¿{cuerpo}"
    s = re.sub(r'(?P<pre>(^|[\.!\…]\s+))(?P<qbody>[^?]*\?)', add_opening_q, s)

    def _open_exclam(m):
        # Wrap known exclamations with the opening '¡' mark.
        palabra = m.group('w')
        resto = m.group('r') or ''
        return f'¡{palabra}!{resto}'
    s = re.sub(r'(?i)^(?P<w>(hola|gracias|genial|perfecto|claro|por supuesto|con gusto|listo|vaya|wow|tu puedes|tú puedes|clarín|clarin|clarín cornetas))!(?P<r>(\s|$).*)',_open_exclam, s)

    # Collapse whitespace and guarantee final punctuation.
    s = re.sub(r'\s+', ' ', s).strip()
    if s and s[-1] not in ".!?…":
        s += "."
    return s
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def anti_echo(response: str, user_text: str) -> str:
    """Strip a leading echo of the user's question from *response*.

    Compares the normalized response against the normalized user text; when
    the response starts by repeating the question, the echoed prefix is cut
    away and leftover leading punctuation is removed.
    """
    rn = normalize_for_route(response)
    un = normalize_for_route(user_text)
    def _clean_leading(s: str) -> str:
        # Drop leftover leading separators (commas, dashes) and whitespace.
        s = re.sub(r'^\s*[,;:\-–—]\s*', '', s)
        s = re.sub(r'^\s+', '', s)
        return s
    # Only treat it as an echo for non-trivial (>= 4 chars) user input.
    if len(un) >= 4 and rn.startswith(un):
        # Try to cut up to the first separator within the first ~120 chars.
        cut = re.sub(r'^\s*[^,;:\.\!\?]{0,120}[,;:\-]\s*', '', response).lstrip()
        if cut and cut != response:
            return _clean_leading(cut)
        # NOTE(review): fallback slices by len(user_text) on the RAW strings
        # while the startswith test used NORMALIZED strings — if
        # normalization changed lengths (accents/punctuation removed) this
        # slice can land mid-word; confirm acceptable for these inputs.
        return _clean_leading(response[len(user_text):])
    return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def normalize_for_route(s: str) -> str:
    """Normalize text for routing: strip accents/punctuation, lowercase,
    collapse runs of whitespace."""
    decomposed = unicodedata.normalize("NFKD", s)
    no_marks = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
    keep_word_chars = re.sub(r"[^\w\s-]", " ", no_marks, flags=re.UNICODE)
    collapsed = re.sub(r"\s+", " ", keep_word_chars)
    return collapsed.strip().lower()
|
|
|
|
|
# Normalized (accent-free) interrogative words that mark a question opener.
_Q_STARTERS = {
    "como","que","quien","quienes","cuando","donde","por que","para que",
    "cual","cuales","cuanto","cuantos","cuanta","cuantas"
}

# Substrings that should trigger exclamation marks (motivational requests).
_EXC_TRIGGERS = {"motiva","motivame","animate","animame","animo","ayudame","ayudame porfa", "clarin", "clarín", "clarinete", "clarin cornetas"}

# Slang greetings that should be left without ¿?/¡! punctuation entirely.
SPECIAL_NOPUNCT = {"kiubo", "quiubo", "que chido", "qué chido", "que buena onda"}

# Verbs that, when leading the sentence, imply a yes/no question.
_Q_VERB_STARTERS = {"eres","estas","estás","puedes","sabes","tienes","quieres","conoces",
                    "crees","piensas","dirias","dirías","podrias","podrías","podras","podrás"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def needs_question_marks(norm: str) -> bool:
    """Return True when *norm* begins like a question yet carries no '?'."""
    if "?" in norm:
        return False
    return any(norm == starter or norm.startswith(starter + " ")
               for starter in _Q_STARTERS)
|
|
|
|
|
def needs_exclam(norm: str) -> bool:
    """Return True when *norm* contains a motivational trigger but no '!'."""
    if "!" in norm:
        return False
    for trigger in _EXC_TRIGGERS:
        if trigger in norm:
            return True
    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_slang_greeting(norm: str) -> bool:
    """Detect Mexican-slang greetings in the already-normalized user text."""
    SHORT = {
        "que pex", "que onda", "ke pex", "k pex", "q onda",
        "kiubo", "quiubo", "quiubole", "quiúbole", "kionda", "q onda", "k onda",
        "que rollo", "ke onda", "que show", "que tranza"
    }
    # Exact greeting match first, then prefix patterns for longer inputs.
    if norm in SHORT:
        return True
    prefix_patterns = (
        r"^(q|k|ke|que)\s+(pex|onda|rollo|show|tranza)\b",
        r"^(kiubo|quiubo|quiubole|quiúbole|quiubol[e]?)\b",
    )
    return any(re.match(pattern, norm) for pattern in prefix_patterns)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def capitalize_spanish(s: str) -> str:
    """Upper-case the first alphabetic character of *s*, skipping any
    leading symbols (e.g. '¿', '¡'), after stripping whitespace."""
    s = s.strip()
    for pos, ch in enumerate(s):
        if ch.isalpha():
            return s[:pos] + ch.upper() + s[pos + 1:]
    # No alphabetic character at all: return unchanged.
    return s
|
|
|
|
|
def smart_autopunct(user_text: str) -> str:
    """Add Spanish opening punctuation (¿/¡) to short user inputs.

    Applies an ordered cascade of rules on the normalized text; the first
    matching rule wins. Long inputs (> 20 chars) are only capitalized.
    """
    s = user_text.strip()
    # Long inputs: assume the user punctuated them; just capitalize.
    if len(s) > 20:
        return capitalize_spanish(s)
    norm = normalize_for_route(s)
    # Slang phrases that read better with no ¿?/¡! at all.
    if norm in SPECIAL_NOPUNCT:
        s = re.sub(r'[¿?!¡]+', '', s).strip()
        return capitalize_spanish(s)
    # "y si ..." openers are questions.
    if norm.startswith("y si "):
        s = f"¿{s}?"
        return capitalize_spanish(s)
    # Closing mark present but opening mark missing: add it.
    if "?" in s and "¿" not in s:
        s = "¿" + s
        return capitalize_spanish(s)
    if "!" in s and "¡" not in s:
        s = "¡" + s
        return capitalize_spanish(s)
    # Slang greetings become exclamations.
    if is_slang_greeting(norm):
        s = f"¡{s}!"
        return capitalize_spanish(s)
    # Interrogative-word openers become questions.
    if needs_question_marks(norm):
        s = f"¿{s}?"
        return capitalize_spanish(s)
    # Leading verb ("puedes", "sabes", ...) implies a yes/no question.
    toks = norm.split()
    if toks and toks[0] in _Q_VERB_STARTERS:
        s = f"¿{s}?"
        return capitalize_spanish(s)
    # Polite request forms ("me ayudas", "podrias", ...).
    if re.match(r"^(me\s+ayudas?|me\s+puedes|podrias?|podras?)\b", norm):
        s = f"¿{s}?"
        return capitalize_spanish(s)
    # Motivational triggers become exclamations.
    if needs_exclam(norm):
        s = f"¡{s}!"
        return capitalize_spanish(s)
    return capitalize_spanish(s)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_seeds(seed: int = 42):
    """Seed the python, NumPy and torch RNGs and force deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Deterministic convolution algorithms, no autotuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
|
|
|
|
|
def persona_style_prompt(persona: str, domain: str) -> str:
    """Short style instruction for a personality and domain (technical/social).

    Only "Enthusiastic" adds an instruction; "Normal" and any unknown
    persona contribute nothing.
    """
    if persona == "Enthusiastic":
        return "Responde de forma creativa, usa al menos 232 palabras. "
    return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_context(question, label_classes, model, tokenizer, device):
    """Classify *question* into one of *label_classes*.

    Runs the sequence-classification *model* on the tokenized question and
    returns the label whose logit is largest.
    """
    model = model.to(device)
    encoded = tokenizer(question, return_tensors="pt", padding=True, truncation=True, max_length=128)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits
    best_index = torch.argmax(logits, dim=1).item()
    return label_classes[best_index]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def technical_asnwer(question, context, model, tokenizer, device, gen_params=None):
    """Generate an answer for a technical *question* grounded on *context*.

    Builds a "Context ... [SEP] Question ..." prompt, decodes with beam
    search or sampling according to gen_params["mode"], shortens the output
    for the "Normal" persona, and returns the polished text.

    Args:
        question: User question (Spanish).
        context: Context label/text found by the classifier.
        model: Seq2seq generation model.
        tokenizer: Tokenizer matching *model*.
        device: Torch device used for inference.
        gen_params: Optional dict of decoding parameters (see sidebar_params);
            missing keys fall back to the sidebar defaults.

    Returns:
        Generated answer string, post-processed by polish_spanish.
    """
    model = model.to(device).eval()
    # FIX: the original called gen_params.get(key) with no fallback, which
    # raised with gen_params=None or a partial dict even though the signature
    # allows it; use one `params` dict with sidebar-matching defaults.
    params = gen_params or {}
    persona_name = params.get("persona", st.session_state.get("persona", "Normal"))
    style = persona_style_prompt(persona_name, "technical")

    input_text = f"{style}Context: {context} [SEP] Question: {question}."

    # Expose the exact prompt in the sidebar "Show generated prompt" box.
    st.session_state["last_prompt"] = input_text
    st.session_state["just_generated"] = True

    enc = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=256).to(device)

    # Forbid '[' so the model cannot emit bracketed meta tags.
    bad_words = ["["]
    bad_ids = [tokenizer(bw, add_special_tokens=False).input_ids for bw in bad_words]

    # Defaults mirror sidebar_params so a missing gen_params still works.
    max_new = int(params.get("max_new_tokens", 128))
    min_new = int(params.get("min_tokens", 16))
    no_repeat = int(params.get("no_repeat_ngram_size", 3))
    rep_pen = float(params.get("repetition_penalty", 1.0))
    mode = params.get("mode", "beam")

    if mode == "sampling":
        temperature = float(params.get("temperature", 0.7))
        top_p = float(params.get("top_p", 0.9))
        kwargs = dict(
            do_sample=True,
            num_beams=1,
            temperature=max(0.1, temperature),   # clamp to sane sampling range
            top_p=min(1.0, max(0.5, top_p)),
            max_new_tokens=max_new,
            min_new_tokens=max(0, min_new),
            no_repeat_ngram_size=no_repeat,
            repetition_penalty=max(1.0, rep_pen),
            bad_words_ids=bad_ids,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    else:
        num_beams = max(2, int(params.get("num_beams", 4)))
        length_penalty = float(params.get("length_penalty", 1.0))
        kwargs = dict(
            do_sample=False,
            num_beams=num_beams,
            length_penalty=length_penalty,
            max_new_tokens=max_new,
            min_new_tokens=max(0, min_new),
            no_repeat_ngram_size=no_repeat,
            repetition_penalty=max(1.0, rep_pen),
            bad_words_ids=bad_ids,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    out_ids = model.generate(
        input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], **kwargs
    )
    text = tokenizer.decode(out_ids[0], skip_special_tokens=True)

    # Keep "Normal" persona answers to a single sentence.
    if persona_name == "Normal":
        text = truncate_sentences(text, max_sentences=1)

    st.session_state["last_response"] = text

    return polish_spanish(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def social_asnwer(question, model, tokenizer, device, gen_params=None, block_web=True):
    """Generate an answer for a social/small-talk *question*.

    Args:
        question: User question (Spanish); used verbatim as the prompt.
        model: Social seq2seq generation model.
        tokenizer: Tokenizer matching *model*.
        device: Torch device used for inference.
        gen_params: Optional dict of decoding parameters (see sidebar_params);
            missing keys fall back to the sidebar defaults.
        block_web: When True, also ban URL-like tokens from the output.

    Returns:
        Generated answer string, polished and capitalized.
    """
    model = model.to(device).eval()
    # FIX: the original called gen_params.get(key) with no fallback, which
    # raised with gen_params=None or a partial dict; also removed the unused
    # locals `prompt_type` and `min_length` (computed but never used).
    params = gen_params or {}
    persona_name = params.get("persona", st.session_state.get("persona", "Normal"))
    prompt = question

    # Expose the exact prompt in the sidebar "Show generated prompt" box.
    st.session_state["last_prompt"] = prompt
    st.session_state["just_generated"] = True

    enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=192).to(device)

    # Ban bracket tags and English boilerplate; optionally URL fragments.
    bad_words = ["[", "Thanks", "thank you"]
    if block_web:
        bad_words += ["website", "http", "www", ".com"]
    bad_ids = [tokenizer(bw, add_special_tokens=False).input_ids for bw in bad_words]

    # Defaults mirror sidebar_params so a missing gen_params still works.
    max_new = int(params.get("max_new_tokens", 128))
    min_tokens = int(params.get("min_tokens", 16))
    no_repeat = int(params.get("no_repeat_ngram_size", 3))
    rep_pen = float(params.get("repetition_penalty", 1.0))
    mode = params.get("mode", "beam")

    if mode == "sampling":
        temperature = float(params.get("temperature", 0.7))
        top_p = float(params.get("top_p", 0.9))
        kwargs = dict(
            do_sample=True, num_beams=1,
            temperature=max(0.1, temperature),
            top_p=min(1.0, max(0.5, top_p)),
            max_new_tokens=max_new,
            min_new_tokens=max(0, min_tokens),
            no_repeat_ngram_size=no_repeat,
            repetition_penalty=max(1.0, rep_pen),
            bad_words_ids=bad_ids,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    else:
        num_beams = max(2, int(params.get("num_beams", 4)))
        length_penalty = float(params.get("length_penalty", 1.0))
        kwargs = dict(
            do_sample=False, num_beams=num_beams, length_penalty=length_penalty,
            max_new_tokens=max_new,
            min_new_tokens=max(0, min_tokens),
            no_repeat_ngram_size=no_repeat,
            repetition_penalty=max(1.0, rep_pen),
            bad_words_ids=bad_ids,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    out_ids = model.generate(
        input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], **kwargs
    )
    text = tokenizer.decode(out_ids[0], skip_special_tokens=True)

    # Keep "Normal" persona answers to two sentences.
    if persona_name == "Normal":
        text = truncate_sentences(text, max_sentences=2)

    text = polish_spanish(text)
    text = capitalize_spanish(text)

    st.session_state["last_response"] = text

    return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def rule_intent_override(user_text: str, predicted_label: str) -> str:
    """Force the 'social' route for short motivational/slang inputs,
    otherwise keep the classifier's prediction."""
    normalized = normalize_for_route(user_text)
    social_only = r"(motivame|motiva|animame|animo|ayudame|que tranza|qué tranza|que tranza)"
    if re.fullmatch(social_only, normalized):
        return "social"
    return predicted_label
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def contextual_asnwer(question, label_classes, context_model, cont_tok,
                      tec_model, tec_tok, soc_model, soc_tok, device, gen_params=None, block_web=True):
    """Route *question* to the social or technical answer generator.

    Classifies the question, applies the rule-based override, optionally
    reseeds the RNGs, then delegates to the matching generator.

    Returns:
        (answer_text, context_label) tuple.

    Note: the original also built a `context_icons` dict and an `icon` local
    that were never used or returned; that dead computation was removed.
    """
    context = classify_context(question, label_classes, context_model, cont_tok, device)
    context = rule_intent_override(question, context)

    # Deterministic generation when the caller supplies a seed.
    if gen_params and "seed" in gen_params:
        set_seeds(gen_params["seed"])

    if context == "social":
        return social_asnwer(question, soc_model, soc_tok, device, gen_params=gen_params, block_web=block_web), context
    else:
        return technical_asnwer(question, context, tec_model, tec_tok, device, gen_params=gen_params), context
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':

    # ---- Session-state defaults -------------------------------------
    ss = st.session_state
    ss.setdefault("historial", [])
    ss.setdefault("last_prompt", "")
    ss.setdefault("last_response", "")
    ss.setdefault("just_generated", False)

    # Sidebar controls supply the decoding parameters for this run.
    GEN_PARAMS = sidebar_params()
    GEN_PARAMS["persona"] = st.session_state.persona

    # Short per-session id used to tag the interaction logs.
    if "user_id" not in st.session_state:
        st.session_state["user_id"] = str(uuid.uuid4())[:8]

    # ---- Model loading from the Hugging Face Hub --------------------
    # NOTE(review): these loads run on every rerun of the script; presumably
    # the HF cache makes them cheap after the first run — consider
    # st.cache_resource wrappers. TODO confirm.
    labels_path = hf_hub_download(repo_id="tecuhtli/assistant-classifier-bert", filename="context_labels.pkl", use_auth_token=HF_TOKEN)
    label_classes = joblib.load(labels_path)

    # Intent classifier (BERT) and its tokenizer.
    context_model = AutoModelForSequenceClassification.from_pretrained("tecuhtli/assistant-classifier-bert", use_auth_token=HF_TOKEN)
    cont_tok = AutoTokenizer.from_pretrained("tecuhtli/assistant-classifier-bert", use_auth_token=HF_TOKEN)

    # Technical-answer generator (T5).
    tec_tok = AutoTokenizer.from_pretrained("tecuhtli/assistant-technical-t5", use_auth_token=HF_TOKEN)
    tec_model = AutoModelForSeq2SeqLM.from_pretrained("tecuhtli/assistant-technical-t5", use_auth_token=HF_TOKEN)

    # Social-answer generator (T5).
    soc_tok = AutoTokenizer.from_pretrained("tecuhtli/assistant-social-t5", use_auth_token=HF_TOKEN)
    soc_model = AutoModelForSeq2SeqLM.from_pretrained("tecuhtli/assistant-social-t5", use_auth_token=HF_TOKEN)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ---- Static page header -----------------------------------------
    st.title("🤖 Your Personal Assistant 🎓")

    st.caption("🙋🏽 You can ask me about technical concepts such as visualization, data cleaning, BI, and more.")
    st.caption("🙇🏽 I can *only* understand and answer in Spanish (🦅🇲🇽🌵).")
    st.caption("➡️ At this stage, I can respond to simple questions such as:")
    st.caption(" • ¿Cómo estás? • ¿Qué es...? • Explícame algo • Define algo • ¿Para qué sirve...?")

    st.caption("😊 If you want to know me better, visit: [hazutecuhtli.github.io](https://github.com/hazutecuhtli/LLMs_FineTuned_Chatbot)")

    st.markdown("<br>", unsafe_allow_html=True)

    st.caption("✏️ Type **'salir'** to exit.")

    # Deferred clearing of the input box, requested by the previous run.
    if st.session_state.pop("_clear_entrada", False):
        if "entrada" in st.session_state:
            del st.session_state["entrada"]

    # Response stashed by the previous run, displayed once after rerun.
    _flash = st.session_state.pop("_flash_response", None)

    # ---- Question form ----------------------------------------------
    with st.form("formulario_assistant"):
        user_question = st.text_area("📝 Escribe tu pregunta aquí", key="entrada", height=100)
        submitted = st.form_submit_button("Responder")

    if submitted:
        if not user_question:
            st.info("Chatbot: ¿Podrías repetir eso? No entendí bien 😅")
        else:
            response, context = contextual_asnwer(
                user_question, label_classes, context_model, cont_tok,
                tec_model, tec_tok, soc_model, soc_tok, device,
                gen_params=GEN_PARAMS, block_web=True,
            )

            # NOTE(review): the timestamp is taken twice, so the question
            # and answer entries can differ by a second — confirm intended.
            hora_actual = dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            st.session_state.historial.append(("Tú", user_question, hora_actual))

            hora_actual = dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            st.session_state.historial.append(("Assistant", response, hora_actual))

            # Persist the interaction to the CSV/JSONL logs.
            saving_interaction(user_question, response, context, st.session_state["user_id"])

            # Stash the response and request input clearing, then rerun so
            # the UI reflects the new history immediately.
            st.session_state["_flash_response"] = response

            st.session_state["_clear_entrada"] = True

            st.rerun()

    if _flash:
        st.success(_flash)

    # ---- Conversation history ---------------------------------------
    if st.session_state.historial:
        st.markdown("---")

        # Build a plain-text transcript (newest first) for download.
        lineas = []
        for msg in reversed(st.session_state.historial):
            if len(msg) == 3:
                autor, texto, hora = msg
                lineas.append(f"[{hora}] {autor}: {texto}")
            else:
                # Legacy 2-tuple entries without a timestamp.
                autor, texto = msg
                lineas.append(f"{autor}: {texto}")
        texto_chat = "\n\n".join(lineas)

        st.download_button(
            label="💾 Descargar conversación como .txt",
            data=texto_chat,
            file_name="conversacion_assistant.txt",
            mime="text/plain",
            use_container_width=True
        )

        # Scrollable chat container (opening tag; closed after the loop).
        st.markdown(
            """
            <div id="chat-container" style="
                max-height: 400px;
                overflow-y: auto;
                padding: 10px;
                border: 1px solid #333;
                border-radius: 10px;
                background: linear-gradient(180deg, #0e0e0e 0%, #1b1b1b 100%);
                margin-top: 10px;
            ">
            """,
            unsafe_allow_html=True
        )

        # Render each message as a chat bubble: user right, assistant left.
        for msg in reversed(st.session_state.historial):
            if len(msg) == 3:
                autor, texto, _ = msg
            else:
                autor, texto = msg

            if autor == "Tú":
                st.markdown(
                    f"""
                    <div style="
                        text-align: right;
                        background-color: #2d2d2d;
                        color: #e6e6e6;
                        padding: 10px 14px;
                        border-radius: 12px;
                        margin: 6px 0;
                        border: 1px solid #3a3a3a;
                        display: inline-block;
                        max-width: 80%;
                        float: right;
                        clear: both;
                    ">
                    🧍‍♂️ <b>{autor}:</b> {texto}
                    </div>
                    """,
                    unsafe_allow_html=True
                )
            else:
                st.markdown(
                    f"""
                    <div style="
                        text-align: left;
                        background-color: #162b1f;
                        color: #d9ead3;
                        padding: 10px 14px;
                        border-radius: 12px;
                        margin: 6px 0;
                        border: 1px solid #264d36;
                        display: inline-block;
                        max-width: 80%;
                        float: left;
                        clear: both;
                    ">
                    🤖 <b>{autor}:</b> {texto}
                    </div>
                    """,
                    unsafe_allow_html=True
                )

        # Close the chat container opened above.
        st.markdown("</div>", unsafe_allow_html=True)
|