#***************************************************************************
# Mori assistant — Streamlit app that routes each question through a context
# classifier to a technical or a social T5 model.
#***************************************************************************
import os
import sys
import warnings
import json
import joblib
import random
import re
import unicodedata
import uuid
import torch
import csv
import numpy as np
# Presumably silences TensorFlow oneDNN notices if transformers pulls TF in;
# it has to be set before transformers is imported below.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
import streamlit as st
import datetime as dt
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer  # RAG embeddings (not referenced in this file)

# =========================
# General configuration
# =========================
HF_TOKEN = os.environ.get("HF_TOKEN")  # Private Hub token (set it in Secrets or as an env variable)

# Help link shared by all text-generation sliders.
_HF_GEN_DOCS = "https://huggingface.co/docs/transformers/main_classes/text_generation"

#***************************************************************************
# Sidebar controls for generation params
#***************************************************************************


def _clamp(value, lo, hi):
    """Return *value* limited to the closed interval [lo, hi]."""
    return max(lo, min(hi, value))


def sidebar_params():
    """Render the sidebar generation controls and return the selected parameters.

    All values live in ``st.session_state`` so they survive Streamlit reruns.

    Returns:
        dict: ``persona``, ``mode`` ('beam' | 'sampling'), ``max_new_tokens``,
        ``min_tokens``, ``no_repeat_ngram_size`` and ``repetition_penalty``,
        plus ``num_beams``/``length_penalty`` in beam mode or
        ``temperature``/``top_p`` in sampling mode.
    """
    ss = st.session_state
    with st.sidebar:
        st.title("🎮 Adjustments (T5-Base)")

        # Defaults, applied only on the first run of the session.
        ss.setdefault("show_llm_controls", False)  # advanced settings start hidden
        ss.setdefault("persona", "Normal")
        ss.setdefault("mode", "beam")  # 'beam' | 'sampling'
        ss.setdefault("max_new", 128)
        ss.setdefault("min_tok", 16)
        ss.setdefault("no_repeat", 3)
        ss.setdefault("num_beams", 4)
        ss.setdefault("length_penalty", 1.0)
        ss.setdefault("temperature", 0.7)
        ss.setdefault("top_p", 0.9)
        ss.setdefault("repetition_penalty", 1.0)

        # ----------------------------
        # Personality presets
        # ----------------------------
        st.header("💡 Predefined Personalities")
        c1, c2 = st.columns(2)
        with c1:
            if st.button("Normal 🧐", use_container_width=True):
                ss.update({
                    "persona": "Normal",
                    "mode": "beam",
                    # The num_beams slider starts at 2 (and the answer functions
                    # clamp to >= 2 anyway); a preset of 1 made st.slider raise.
                    "num_beams": 2,
                    "max_new": 92,
                    "min_tok": 32,
                    "no_repeat": 3,
                    "length_penalty": .3,
                    "temperature": 0.4,
                    "top_p": 0.9,
                    # Values below 1.0 are clamped to 1.0 at generation time.
                    "repetition_penalty": 1.0,
                })
                st.rerun()
        with c2:
            if st.button("Enthusiastic 😃", use_container_width=True):
                ss.update({
                    "persona": "Enthusiastic",
                    "mode": "sampling",
                    "max_new": 192,
                    "min_tok": 48,
                    "no_repeat": 3,
                    "temperature": .8,
                    "top_p": 0.95,
                    "repetition_penalty": 1.0,
                })
                st.rerun()

        st.caption(f"Selected Personality: **{ss.persona}**")

        # ----------------------------
        # Show/hide the advanced settings
        # ----------------------------
        if st.button(("🔼 Hide" if ss.show_llm_controls else "🔽 Show") + " Advanced Settings"):
            ss.show_llm_controls = not ss.show_llm_controls
            st.rerun()

        # ----------------------------
        # Manual model controls (strategy, sliders)
        # ----------------------------
        if ss.show_llm_controls:
            st.header("⚙️ Manual Adjustments")
            st.subheader("📝 Text Generation")
            picked = st.radio(
                "Strategy",
                ["Beam search (stable)", "Sampling (creative)"],
                index=0 if ss.mode == "beam" else 1,
                help="https://huggingface.co/docs/transformers/generation_strategies",
            )
            ss.mode = "beam" if picked.startswith("Beam") else "sampling"

            st.subheader("🔧 LLM text generation parameters")
            # Stored values are clamped into each slider's range first: st.slider
            # raises when its initial value lies outside [min_value, max_value]
            # (e.g. min_tok=48 after max_new was lowered to 32).
            ss.max_new = st.slider(
                "max_new_tokens", 16, 256, _clamp(int(ss.max_new), 16, 256),
                step=8, help=_HF_GEN_DOCS,
            )
            ss.min_tok = st.slider(
                "min_tokens", 0, int(ss.max_new), _clamp(int(ss.min_tok), 0, int(ss.max_new)),
                help=_HF_GEN_DOCS,
            )
            ss.no_repeat = st.slider(
                "no_repeat_ngram_size", 0, 6, _clamp(int(ss.no_repeat), 0, 6),
                help=_HF_GEN_DOCS,
            )

            # Mode-specific sub-controls
            if ss.mode == "beam":
                ss.num_beams = st.slider(
                    "num_beams", 2, 8, _clamp(int(ss.num_beams), 2, 8),
                    help=_HF_GEN_DOCS,
                )
                ss.length_penalty = st.slider(
                    "length_penalty", 0.0, 2.0, _clamp(float(ss.length_penalty), 0.0, 2.0),
                    step=0.1, help=_HF_GEN_DOCS,
                )
            else:
                ss.temperature = st.slider(
                    "temperature", 0.1, 1.5, _clamp(float(ss.temperature), 0.1, 1.5),
                    step=0.05, help=_HF_GEN_DOCS,
                )
                ss.top_p = st.slider(
                    "top_p", 0.5, 1.0, _clamp(float(ss.top_p), 0.5, 1.0),
                    step=0.01, help=_HF_GEN_DOCS,
                )

            # Debug view of the last prompt sent to a model.
            if ss.get("last_prompt"):
                with st.expander("Show generated prompt"):
                    st.text_area(
                        "Prompt actual:", ss["last_prompt"], height=200, disabled=True
                    )
            else:
                st.caption("👉 No prompt is available yet.")

    # ----------------------------
    # Build the parameter dictionary
    # ----------------------------
    max_new = int(ss.max_new)
    params = {
        "persona": ss.persona,
        "mode": ss.mode,
        "max_new_tokens": max_new,
        # Never ask for more minimum tokens than the maximum allows.
        "min_tokens": min(int(ss.min_tok), max_new),
        "no_repeat_ngram_size": int(ss.no_repeat),
        "repetition_penalty": float(ss.repetition_penalty),
    }
    if ss.mode == "beam":
        params.update({
            "num_beams": int(ss.num_beams),
            "length_penalty": float(ss.length_penalty),
        })
    else:
        params.update({
            "temperature": float(ss.temperature),
            "top_p": float(ss.top_p),
        })
    return params

#***************************************************************************
# Functions
#***************************************************************************

# Splits after sentence-ending punctuation followed by whitespace.
_SENT_SPLIT = re.compile(r'(?<=[\.\!\?…])\s+')


def truncate_sentences(text: str, max_sentences: int = 4) -> str:
    """Keep at most *max_sentences* sentences of *text*.

    The result is stripped and ends with sentence punctuation; an empty input
    returns an empty string.
    """
    s = text.strip()
    if not s:
        return s
    parts = _SENT_SPLIT.split(s)
    cut = " ".join(parts[:max_sentences]).strip()
    if cut and cut[-1] not in ".!?…":
        cut += "."
    return cut


def _load_json_safe(path: Path, fallback: dict) -> dict:
    """Load JSON from *path*, returning *fallback* if the file is missing or invalid."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file; ValueError covers JSONDecodeError
        # and UnicodeDecodeError.
        return fallback


def limpiar_input():
    """Clear the question text area (session-state key ``entrada``)."""
    st.session_state["entrada"] = ""


def get_model_path(folder_name):
    """Return the relative path ``Models/<folder_name>``."""
    return Path("Models") / folder_name


def saving_interaction(question, response, context, user_id):
    """Append one interaction to the CSV and JSONL logs under ``Statistics/``.

    Args:
        question: user input question.
        response: assistant response to the question.
        context: context label chosen by the classifier (or rule override).
        user_id: ID of the current user (unique per session).
    """
    timestamp = dt.datetime.now().isoformat()
    stats_dir = Path("Statistics")
    stats_dir.mkdir(parents=True, exist_ok=True)

    archivo_csv = stats_dir / "conversaciones_log.csv"
    existe_csv = archivo_csv.exists()
    with open(archivo_csv, mode="a", encoding="utf-8", newline="") as f_csv:
        writer = csv.writer(f_csv)
        if not existe_csv:
            # Header only when the file is created.
            writer.writerow(["timestamp", "user_id", "contexto", "pregunta", "respuesta"])
        writer.writerow([timestamp, user_id, context, question, response])

    archivo_jsonl = stats_dir / "conversaciones_log.jsonl"
    with open(archivo_jsonl, mode="a", encoding="utf-8") as f_jsonl:
        registro = {
            "timestamp": timestamp,
            "user_id": user_id,
            "context": context,
            "pregunta": question,
            "respuesta": response,
        }
        f_jsonl.write(json.dumps(registro, ensure_ascii=False) + "\n")


@st.cache_resource
def load_model(path_str):
    """Load a local seq2seq model and tokenizer (cached per path by Streamlit).

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    path = Path(path_str).resolve()
    tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(path, local_files_only=True)
    return model, tokenizer

#-------------------------------------------------------------------------
# Function to correct Spanish sentences' punctuation and missing characters
#-------------------------------------------------------------------------
def polish_spanish(s: str) -> str:
    """Repair common Spanish spelling/accent/punctuation defects in model output.

    Steps:
      1. NFC-normalize and strip; drop "[Assistant Social/Técnico]"-style tags.
      2. Apply an ordered list of regex word fixes (lost accents, lost "ñ", ...).
      3. Add the opening "¿" to each question sentence that lacks one, and the
         opening "¡" to a few known exclamations at the start of the text.
      4. Collapse whitespace and make sure the text ends with punctuation.

    Returns an empty string for empty/whitespace input.
    """
    s = unicodedata.normalize("NFC", s).strip()
    s = re.sub(r'\s*[\[\(]\s*Assistant\s+(?:Social|T[eé]nico|T[eé]cnico)\s*[\]\)]\s*', '', s, flags=re.I)

    # Order matters: later patterns see the output of earlier ones.
    fixes = [
        (r'(?i)(^|\W)T\s+puedes(?P<end>[^\w]|$)', r'\1Tú puedes\g<end>'),
        (r'(?i)(^|\W)T\s+(ya|eres|estas|estás|tienes|puedes)\b', r'\1Tú \2'),
        (r'(?i)\bclaro que s(?:i|í)?\b(?P<punct>[,.\!?…])?', r'Claro que sí\g<punct>'),
        (r'(?i)(^|\s)si,', r'\1Sí,'),
        (r'(?i)(\beso\s+)s(\s+est[áa]\b)', r'\1sí\2'),
        (r'(?i)(^|[\s,;:])s(\s+es\b)', r'\1sí\2'),
        (r'(?i)\btiles\b', 'útiles'),
        (r'(?i)\butiles\b', 'útiles'),
        (r'(?i)\butil\b', 'útil'),
        (r'(?i)\baqui\b', 'aquí'),
        (r'(?i)\baqu\b(?=\s+estoy\b)', 'aquí'),
        (r'(?i)\balgn\b', 'algún'),
        (r'(?i)\balgun\b', 'algún'),
        (r'(?i)\bAnimo\b', 'Ánimo'),
        (r'(?i)\bcario\b', 'cariño'),
        (r'(?i)\baprendisaje\b', 'aprendizaje'),
        (r'(?i)\bmanana\b', 'mañana'),
        (r'(?i)\bmaana\b', 'mañana'),
        (r'(?i)\benergia\b', 'energía'),
        (r'(?i)\benerga\b', 'energía'),
        (r'(?i)\bextrano\b', 'extraño'),
        (r'(?i)\bextrana\b', 'extraña'),
        (r'(?i)\bextranar\b', 'extrañar'),
        (r'(?i)\bextranarte\b', 'extrañarte'),
        (r'(?i)\bextranas\b', 'extrañas'),
        (r'(?i)\bextranos\b', 'extraños'),
        (r'(?i)\baqu\b', 'aquí'),
        (r'(?i)\bestare\b', 'estaré'),
        (r'(?i)\bclarn\b', 'clarín'),
        (r'(?i)\bclarin\b', 'clarín'),
        (r'(?i)\bclar[íi]n\s+cornetas\b', 'clarín cornetas'),
        (r'(?i)(^|\s)s([,.;:!?])', r'\1Sí\2'),
        (r'(?i)\bfutbol\b', 'fútbol'),
        (r'(?i)(^|\s)as(\s+se\b)', r'\1Así\2'),
        (r'(?i)(^|\s)s(\s+orientarte\b)', r'\1sí\2'),
        (r'(?i)\bbuen dia\b', 'buen día'),
        (r'(?i)\bgran dia\b', 'gran día'),
        (r'(?i)\bdias\b', 'días'),
        (r'(?i)\bdia\b', 'día'),
        (r'(?i)\bgran da\b', 'gran día'),
        (r'(?i)\bacompa?a(r|rte|do|da|dos|das)?\b', r'acompaña\1'),
        (r'(?i)(^|\s)as([,.;:!?]|\s|$)', r'\1así\2'),
        (r'(?i)(^|\s)S lo se\b', r'\1Sí lo sé'),
        (r'(?i)(^|\s)S lo sé\b', r'\1Sí lo sé'),
        (r'(?i)\bcudese\b', 'cuídese'),
        (r'(?i)\bpequeo\b', 'pequeño'),
        (r'(?i)\bpequea\b', 'pequeña'),
        (r'(?i)\bpequeos\b', 'pequeños'),
        (r'(?i)\bpequeas\b', 'pequeñas'),
        (r'(?i)\bunico\b', 'único'),
        (r'(?i)\bunica\b', 'única'),
        (r'(?i)\bunicos\b', 'únicos'),
        (r'(?i)\bunicas\b', 'únicas'),
        (r'(?i)\bnico\b', 'único'),
        (r'(?i)\bnica\b', 'única'),
        (r'(?i)\bnicos\b', 'únicos'),
        (r'(?i)\bnicas\b', 'únicas'),
        (r'(?i)\bestadstico\b', 'estadístico'),
        (r'(?i)\bestadstica\b', 'estadística'),
        (r'(?i)\bestadsticos\b', 'estadísticos'),
        (r'(?i)\bestadsticas\b', 'estadísticas'),
        (r'(?i)\bcudate\b', 'cuídate'),
        (r'(?i)\bcuidate\b', 'cuídate'),
        (r'(?i)\bcuidese\b', 'cuídese'),
        (r'(?i)\bcuidense\b', 'cuídense'),
        (r'(?i)\bcudense\b', 'cuídense'),
        (r'(?i)\bgracias por confiar en m\b', 'gracias por confiar en mí'),
        (r'(?i)\bcada dia\b', 'cada día'),
        (r'(?i)\bcada da\b', 'cada día'),
        (r'(?i)\bsegun\b', 'según'),
        (r'(?i)\bcaracteristica(s)?\b', r'característica\1'),
        (r'(?i)\bcaracterstica(s)?\b', r'característica\1'),
        (r'(?i)\b([a-záéíóúñ]+)cion\b', r'\1ción'),
        (r'(?i)\bdeterminacio\b', 'determinación'),
    ]
    for pat, rep in fixes:
        s = re.sub(pat, rep, s)

    s = re.sub(r'(?i)^eso es todo!(?P<rest>(\s|$).*)', r'¡Eso es todo!\g<rest>', s)

    def add_opening_q(m):
        # Leave sentences that already carry an opening "¿" untouched.
        cuerpo = m.group('qbody')
        if '¿' in cuerpo:
            return m.group(0)
        return f"{m.group('pre')}¿{cuerpo}"

    # A question body may not cross a sentence boundary (".", "!", "…" followed
    # by whitespace); otherwise "A. B?" would become "¿A. B?" instead of "A. ¿B?".
    # A "." not followed by whitespace (e.g. "3.5") stays inside the body.
    s = re.sub(
        r'(?P<pre>(^|[.!…]\s+))(?P<qbody>(?:[^?.!…]|[.!…](?!\s))*\?)',
        add_opening_q,
        s,
    )

    def _open_exclam(m):
        palabra = m.group('w')
        resto = m.group('r') or ''
        return f'¡{palabra}!{resto}'

    s = re.sub(
        r'(?i)^(?P<w>(hola|gracias|genial|perfecto|claro|por supuesto|con gusto|listo|vaya|wow|tu puedes|tú puedes|clarín|clarin|clarín cornetas))!(?P<r>(\s|$).*)',
        _open_exclam,
        s,
    )

    s = re.sub(r'\s+', ' ', s).strip()
    if s and s[-1] not in ".!?…":
        s += "."
    return s

#-------------------------------------------------------------------------
# Function to remove repeated input in the Model answer
#-------------------------------------------------------------------------

def anti_echo(response: str, user_text: str) -> str:
    """Remove an echoed copy of *user_text* from the start of *response*.

    The echo is detected on ``normalize_for_route`` forms of both strings, so
    case, accents and punctuation differences are tolerated. Inputs whose
    normalized form is shorter than 4 characters are never treated as echoes.
    """
    def _clean_leading(s: str) -> str:
        # Drop separators / closing punctuation left over once the echo is cut.
        s = re.sub(r'^\s*[,;:\-–—?!.…]+\s*', '', s)
        return s.lstrip()

    rn = normalize_for_route(response)
    un = normalize_for_route(user_text)
    if len(un) < 4 or not rn.startswith(un):
        return response

    # First try: cut everything up to the first clause separator (<= 120 chars).
    cut = re.sub(r'^\s*[^,;:\.\!\?]{0,120}[,;:\-]\s*', '', response).lstrip()
    if cut and cut != response:
        return _clean_leading(cut)

    # Fallback: find where the echo ends in the RAW response. Slicing by
    # len(user_text) is wrong because the match was made on normalized text
    # ("¿Qué es X?" vs "que es x" differ in raw length).
    for end in range(1, len(response) + 1):
        if normalize_for_route(response[:end]).startswith(un):
            return _clean_leading(response[end:])
    return response

#-------------------------------------------------------------------------
# Normalization helpers
#-------------------------------------------------------------------------

def normalize_for_route(s: str) -> str:
    """Return an accent-free, punctuation-free, lower-case, single-spaced form of *s*.

    Letters, digits, underscores, whitespace and hyphens survive; every other
    character becomes a space before whitespace runs are collapsed.
    """
    decomposed = unicodedata.normalize("NFKD", s)
    # Combining marks are the accents split off by NFKD ("é" -> "e" + U+0301).
    bare = "".join(filter(lambda ch: not unicodedata.combining(ch), decomposed))
    spaced = re.sub(r"[^\w\s-]", " ", bare, flags=re.UNICODE)
    return " ".join(spaced.split()).lower()

# Interrogative words that open a question (accent-free; in this file they are
# compared against normalize_for_route() output).
_Q_STARTERS = {
    "como", "que", "quien", "quienes",
    "cuando", "donde",
    "por que", "para que",
    "cual", "cuales",
    "cuanto", "cuantos", "cuanta", "cuantas",
}

# Substrings that make a short input read as an exclamation. Accented entries
# cannot match accent-free normalized text; they are kept as-is.
_EXC_TRIGGERS = {
    "motiva", "motivame",
    "animate", "animame", "animo",
    "ayudame", "ayudame porfa",
    "clarin", "clarín", "clarinete", "clarin cornetas",
}

# Short inputs that must be left without ¿?¡! punctuation.
SPECIAL_NOPUNCT = {
    "kiubo", "quiubo",
    "que chido", "qué chido",
    "que buena onda",
}

# First words that turn an input into a yes/no question.
_Q_VERB_STARTERS = {
    "eres", "estas", "estás",
    "puedes", "sabes", "tienes", "quieres", "conoces",
    "crees", "piensas",
    "dirias", "dirías",
    "podrias", "podrías", "podras", "podrás",
}

#-------------------------------------------------------------------------
# Punctuation helpers
#-------------------------------------------------------------------------

def needs_question_marks(norm: str) -> bool:
    """Return True when *norm* has no '?' and opens with a word from ``_Q_STARTERS``.

    A starter counts when it is the whole string or is followed by a space.
    """
    if "?" in norm:
        return False
    return any(norm == starter or norm.startswith(f"{starter} ") for starter in _Q_STARTERS)

def needs_exclam(norm: str) -> bool:
    """Return True when *norm* has no '!' and contains any ``_EXC_TRIGGERS`` substring."""
    return "!" not in norm and any(trigger in norm for trigger in _EXC_TRIGGERS)

#-------------------------------------------------------------------------
# Greetings detection
#-------------------------------------------------------------------------

def is_slang_greeting(norm: str) -> bool:
    """Return True when *norm* (already normalized) is a Mexican slang greeting.

    It matches either an exact known form ("que onda", "kiubo", ...) or a
    recognised greeting prefix followed by a word boundary.
    """
    exact_forms = {
        "que pex", "ke pex", "k pex",
        "que onda", "q onda", "k onda", "ke onda", "kionda",
        "kiubo", "quiubo", "quiubole", "quiúbole",
        "que rollo", "que show", "que tranza",
    }
    if norm in exact_forms:
        return True
    prefix_patterns = (
        r"^(q|k|ke|que)\s+(pex|onda|rollo|show|tranza)\b",
        r"^(kiubo|quiubo|quiubole|quiúbole|quiubol[e]?)\b",
    )
    return any(re.match(pattern, norm) for pattern in prefix_patterns)

#-------------------------------------------------------------------------
# Capitalization & autopunct
#-------------------------------------------------------------------------

def capitalize_spanish(s: str) -> str:
    """Strip *s* and upper-case its first alphabetic character.

    Leading non-letters such as "¿", "¡" or digits are skipped, so
    "¿hola?" becomes "¿Hola?". Text without letters is returned stripped.
    """
    text = s.strip()
    first_alpha = next((idx for idx, ch in enumerate(text) if ch.isalpha()), None)
    if first_alpha is None:
        return text
    return f"{text[:first_alpha]}{text[first_alpha].upper()}{text[first_alpha + 1:]}"

def smart_autopunct(user_text: str) -> str:
    """Add Spanish opening/closing punctuation to short user inputs, then capitalize.

    Inputs longer than 20 characters are only capitalized. For shorter inputs
    the first matching rule wins:
      1. special phrases (SPECIAL_NOPUNCT) -> all ¿?¡! removed
      2. starts with "y si " -> wrapped in ¿…?
      3. has '?' but no '¿' -> '¿' prepended
      4. has '!' but no '¡' -> '¡' prepended
      5. slang greeting -> wrapped in ¡…!
      6. interrogative opener, yes/no verb opener, or "me ayudas / podrías…"
         -> wrapped in ¿…?
      7. motivational trigger -> wrapped in ¡…!
    """
    text = user_text.strip()
    if len(text) > 20:
        return capitalize_spanish(text)

    norm = normalize_for_route(text)
    first_word = norm.split()[:1]

    if norm in SPECIAL_NOPUNCT:
        text = re.sub(r'[¿?!¡]+', '', text).strip()
    elif norm.startswith("y si "):
        text = f"¿{text}?"
    elif "?" in text and "¿" not in text:
        text = "¿" + text
    elif "!" in text and "¡" not in text:
        text = "¡" + text
    elif is_slang_greeting(norm):
        text = f"¡{text}!"
    elif (needs_question_marks(norm)
          or (first_word and first_word[0] in _Q_VERB_STARTERS)
          or re.match(r"^(me\s+ayudas?|me\s+puedes|podrias?|podras?)\b", norm)):
        text = f"¿{text}?"
    elif needs_exclam(norm):
        text = f"¡{text}!"

    return capitalize_spanish(text)


#-------------------------------------------------------------------------
# Seeds & helpers
#-------------------------------------------------------------------------

def set_seeds(seed: int = 42):
    """Seed the Python, NumPy and PyTorch RNGs (and all CUDA devices if present).

    It also forces deterministic cuDNN kernels and disables cuDNN autotuning.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# --- Personalidades (solo estilo en prompt; parámetros ya vienen del sidebar) ---

def persona_style_prompt(persona: str, domain: str) -> str:
    """Return the style instruction prefixed to the prompt for *persona*.

    Only "Enthusiastic" has an instruction. Every other persona, including
    "Normal", gets an empty prefix. *domain* is accepted for interface
    compatibility but is not used.
    """
    style_by_persona = {
        "Enthusiastic": "Responde de forma creativa, usa al menos 232 palabras. ",
    }
    return style_by_persona.get(persona, "")

#-------------------------------------------------------------------------
# Classifier
#-------------------------------------------------------------------------

def classify_context(question, label_classes, model, tokenizer, device):
    """Predict the context label of *question* with a sequence-classification model.

    Args:
        question: user text to classify.
        label_classes: indexable sequence mapping class index -> label name.
        model: classifier whose output exposes ``.logits``.
        tokenizer: callable returning a mapping of input tensors.
        device: torch device to run inference on.

    Returns:
        ``label_classes[argmax(logits)]`` for the first (only) input.
    """
    # eval() disables dropout for deterministic inference, matching the
    # answer-generation functions in this file.
    model = model.to(device).eval()
    inputs = tokenizer(question, return_tensors="pt", padding=True, truncation=True, max_length=128)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    pred_idx = int(torch.argmax(logits, dim=1).item())
    return label_classes[pred_idx]

#-------------------------------------------------------------------------
# Chatbot response for technical contexts using a Hugging Face model
#-------------------------------------------------------------------------

def technical_asnwer(question, context, model, tokenizer, device, gen_params=None):
    """Generate an answer for a technical *context* with the technical seq2seq model.

    Args:
        question: user question.
        context: context label chosen by the router (inserted into the prompt).
        model, tokenizer: Hugging Face seq2seq model and its tokenizer.
        device: torch device for generation.
        gen_params: dict produced by ``sidebar_params()``. May be None or
            partial; missing keys fall back to the sidebar defaults.

    Returns:
        str: the polished answer. It is also stored in
        ``st.session_state["last_response"]``, next to ``last_prompt`` and
        ``just_generated``.
    """
    gp = gen_params or {}
    model = model.to(device).eval()
    persona_name = gp.get("persona", st.session_state.get("persona", "Normal"))
    style = persona_style_prompt(persona_name, "technical")

    # Prompt engineering: optional style instruction + context + question.
    input_text = f"{style}Context: {context} [SEP] Question: {question}."
    st.session_state["last_prompt"] = input_text
    st.session_state["just_generated"] = True

    enc = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=256).to(device)

    # Ban "[" (presumably to avoid bracketed role tags in the output).
    bad_words = ["["]
    bad_ids = [tokenizer(bw, add_special_tokens=False).input_ids for bw in bad_words]

    # Defaults mirror the sidebar defaults, so a missing or None gen_params works.
    max_new = int(gp.get("max_new_tokens", 128))
    min_new = min(max(0, int(gp.get("min_tokens", 16))), max_new)
    no_repeat = int(gp.get("no_repeat_ngram_size", 3))
    rep_pen = max(1.0, float(gp.get("repetition_penalty", 1.0)))

    kwargs = dict(
        max_new_tokens=max_new,
        min_new_tokens=min_new,
        no_repeat_ngram_size=no_repeat,
        repetition_penalty=rep_pen,
        bad_words_ids=bad_ids,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    if gp.get("mode", "beam") == "sampling":
        kwargs.update(
            do_sample=True,
            num_beams=1,
            temperature=max(0.1, float(gp.get("temperature", 0.7))),
            top_p=min(1.0, max(0.5, float(gp.get("top_p", 0.9)))),
        )
    else:
        kwargs.update(
            do_sample=False,
            num_beams=max(2, int(gp.get("num_beams", 4))),
            length_penalty=float(gp.get("length_penalty", 1.0)),
        )

    out_ids = model.generate(
        input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], **kwargs
    )
    text = tokenizer.decode(out_ids[0], skip_special_tokens=True)

    # The "Normal" persona keeps technical answers to a single sentence.
    if persona_name == "Normal":
        text = truncate_sentences(text, max_sentences=1)

    text = polish_spanish(text)
    # Store the text actually returned to the user (consistent with social_asnwer).
    st.session_state["last_response"] = text
    return text

#-------------------------------------------------------------------------
# Chatbot response for social contexts using a Hugging Face model
#-------------------------------------------------------------------------

def social_asnwer(question, model, tokenizer, device, gen_params=None, block_web=True):
    """Generate a reply for the 'social' context with the social seq2seq model.

    The raw question is used as the prompt, with no style prefix.

    Args:
        question: user text.
        model, tokenizer: Hugging Face seq2seq model and its tokenizer.
        device: torch device for generation.
        gen_params: dict produced by ``sidebar_params()``. May be None or
            partial; missing keys fall back to the sidebar defaults.
        block_web: when True, also bans "website", "http", "www" and ".com".

    Returns:
        str: the polished, capitalized reply. It is also stored in
        ``st.session_state["last_response"]``.
    """
    gp = gen_params or {}
    model = model.to(device).eval()
    persona_name = gp.get("persona", st.session_state.get("persona", "Normal"))

    prompt = question
    st.session_state["last_prompt"] = prompt
    st.session_state["just_generated"] = True

    enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=192).to(device)

    bad_words = ["[", "Thanks", "thank you"]
    if block_web:
        bad_words += ["website", "http", "www", ".com"]
    bad_ids = [tokenizer(bw, add_special_tokens=False).input_ids for bw in bad_words]

    # Defaults mirror the sidebar defaults, so a missing or None gen_params works.
    max_new = int(gp.get("max_new_tokens", 128))
    min_new = min(max(0, int(gp.get("min_tokens", 16))), max_new)
    no_repeat = int(gp.get("no_repeat_ngram_size", 3))
    rep_pen = max(1.0, float(gp.get("repetition_penalty", 1.0)))

    kwargs = dict(
        max_new_tokens=max_new,
        min_new_tokens=min_new,
        no_repeat_ngram_size=no_repeat,
        repetition_penalty=rep_pen,
        bad_words_ids=bad_ids,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    if gp.get("mode", "beam") == "sampling":
        kwargs.update(
            do_sample=True,
            num_beams=1,
            temperature=max(0.1, float(gp.get("temperature", 0.7))),
            top_p=min(1.0, max(0.5, float(gp.get("top_p", 0.9)))),
        )
    else:
        kwargs.update(
            do_sample=False,
            num_beams=max(2, int(gp.get("num_beams", 4))),
            length_penalty=float(gp.get("length_penalty", 1.0)),
        )

    out_ids = model.generate(
        input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], **kwargs
    )
    text = tokenizer.decode(out_ids[0], skip_special_tokens=True)

    # The "Normal" persona keeps social replies to two sentences.
    if persona_name == "Normal":
        text = truncate_sentences(text, max_sentences=2)
    # NOTE: anti_echo() could strip an echoed question here, but it is not applied.
    text = capitalize_spanish(polish_spanish(text))

    st.session_state["last_response"] = text
    return text

#-------------------------------------------------------------------------
# Rule overrides
#-------------------------------------------------------------------------

def rule_intent_override(user_text: str, predicted_label: str) -> str:
    """Force the 'social' route when the whole normalized input is a known short phrase.

    Otherwise the classifier's *predicted_label* is returned unchanged.
    """
    forced_social = {
        "motivame", "motiva", "animame", "animo", "ayudame",
        "que tranza", "qué tranza",
    }
    return "social" if normalize_for_route(user_text) in forced_social else predicted_label

#-------------------------------------------------------------------------
# Router
#-------------------------------------------------------------------------

def contextual_asnwer(question, label_classes, context_model, cont_tok,
                      tec_model, tec_tok, soc_model, soc_tok, device, gen_params=None, block_web=True):
    """Classify *question* and route it to the social or the technical generator.

    The classifier label can be overridden by ``rule_intent_override``. If
    *gen_params* contains ``"seed"``, the RNGs are seeded before generation.

    Returns:
        tuple: ``(answer_text, context_label)``.
    """
    context = classify_context(question, label_classes, context_model, cont_tok, device)
    context = rule_intent_override(question, context)

    if gen_params and "seed" in gen_params:
        set_seeds(gen_params["seed"])

    if context == "social":
        answer = social_asnwer(question, soc_model, soc_tok, device,
                               gen_params=gen_params, block_web=block_web)
    else:
        answer = technical_asnwer(question, context, tec_model, tec_tok, device,
                                  gen_params=gen_params)
    return answer, context

#***************************************************************************
# MAIN
#***************************************************************************

if __name__ == '__main__':
    import html  # escapes chat text before it is injected as raw HTML

    # --- State that must persist across Streamlit reruns ---
    ss = st.session_state
    ss.setdefault("historial", [])
    ss.setdefault("last_prompt", "")
    ss.setdefault("last_response", "")
    ss.setdefault("just_generated", False)

    # Sidebar: full control over the generation parameters.
    GEN_PARAMS = sidebar_params()
    GEN_PARAMS["persona"] = st.session_state.persona  # defensive re-sync

    # Short per-session ID stored with every logged interaction.
    if "user_id" not in st.session_state:
        st.session_state["user_id"] = str(uuid.uuid4())[:8]

    @st.cache_resource
    def _load_assistant_models(token):
        """Load the label encoder and the three Hub models once per server process.

        Streamlit re-executes this script on every interaction. Without this
        cache, every click would call from_pretrained for all three models.
        """
        labels_path = hf_hub_download(
            repo_id="tecuhtli/assistant-classifier-bert",
            filename="context_labels.pkl",
            use_auth_token=token,
        )
        # NOTE(security): joblib.load unpickles the downloaded file. Only
        # acceptable while the Hub repo is trusted — verify repo ownership.
        labels = joblib.load(labels_path)

        # Context classifier
        clf_model = AutoModelForSequenceClassification.from_pretrained(
            "tecuhtli/assistant-classifier-bert", use_auth_token=token)
        clf_tok = AutoTokenizer.from_pretrained(
            "tecuhtli/assistant-classifier-bert", use_auth_token=token)

        # Technical model
        t_tok = AutoTokenizer.from_pretrained(
            "tecuhtli/assistant-technical-t5", use_auth_token=token)
        t_model = AutoModelForSeq2SeqLM.from_pretrained(
            "tecuhtli/assistant-technical-t5", use_auth_token=token)

        # Social model
        s_tok = AutoTokenizer.from_pretrained(
            "tecuhtli/assistant-social-t5", use_auth_token=token)
        s_model = AutoModelForSeq2SeqLM.from_pretrained(
            "tecuhtli/assistant-social-t5", use_auth_token=token)

        return labels, clf_model, clf_tok, t_model, t_tok, s_model, s_tok

    (label_classes, context_model, cont_tok,
     tec_model, tec_tok, soc_model, soc_tok) = _load_assistant_models(HF_TOKEN)

    # Available device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Assistant presentation
    st.title("🤖 Your Personal Assistant 🎓")

    st.caption("🙋🏽‍ You can ask me about technical concepts such as visualization, data cleaning, BI, and more.")
    st.caption("🙇🏽 I can *only* understand and answer in Spanish (🦅🇲🇽🌵).")
    st.caption("➡️ At this stage, I can respond to simple questions such as:")
    st.caption("   • ¿Cómo estás?   • ¿Qué es...?   • Explícame algo   • Define algo   • ¿Para qué sirve...?")

    st.caption("😊 If you want to know me better, visit: [hazutecuhtli.github.io](https://github.com/hazutecuhtli/LLMs_FineTuned_Chatbot)")

    st.markdown("<br>", unsafe_allow_html=True)  # vertical spacer
    st.caption("✏️ Type **'salir'** to exit.")
    # NOTE(review): the caption advertises 'salir', but nothing below handles it.

    # Clear the text area BEFORE the widget is created in this run: Streamlit
    # forbids changing a widget's state after it has been instantiated.
    if st.session_state.pop("_clear_entrada", False):
        if "entrada" in st.session_state:
            del st.session_state["entrada"]

    # Response saved by the previous run; displayed below the form.
    _flash = st.session_state.pop("_flash_response", None)

    with st.form("formulario_assistant"):
        user_question = st.text_area("📝 Escribe tu pregunta aquí", key="entrada", height=100)
        submitted = st.form_submit_button("Responder")

    if submitted:
        # Whitespace-only input counts as empty.
        if not (user_question or "").strip():
            st.info("Chatbot: ¿Podrías repetir eso? No entendí bien 😅")
        else:
            response, context = contextual_asnwer(
                user_question, label_classes, context_model, cont_tok,
                tec_model, tec_tok, soc_model, soc_tok, device,
                gen_params=GEN_PARAMS, block_web=True,
            )

            # Chat history
            hora_actual = dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            st.session_state.historial.append(("Tú", user_question, hora_actual))
            st.session_state.historial.append(("Assistant", response, hora_actual))

            # Persist the interaction
            saving_interaction(user_question, response, context, st.session_state["user_id"])

            # Show the response after the rerun and clear the textarea next cycle.
            st.session_state["_flash_response"] = response
            st.session_state["_clear_entrada"] = True

            # Force a refresh so the sidebar sees the new prompt.
            st.rerun()

    # Current response (flash), right below the form.
    if _flash:
        st.success(_flash)

    # Chat-style history inside a scrollable container.
    if st.session_state.historial:
        st.markdown("---")

        # Plain-text transcript for download (newest first, like the chat view).
        lineas = []
        for msg in reversed(st.session_state.historial):
            if len(msg) == 3:
                autor, texto, hora = msg
                lineas.append(f"[{hora}] {autor}: {texto}")
            else:
                autor, texto = msg
                lineas.append(f"{autor}: {texto}")
        texto_chat = "\n\n".join(lineas)

        st.download_button(
            label="💾 Descargar conversación como .txt",
            data=texto_chat,
            file_name="conversacion_assistant.txt",
            mime="text/plain",
            use_container_width=True,
        )

        # All bubbles go in ONE markdown call: each st.markdown call becomes a
        # separate element, so a <div> opened in one call cannot wrap later ones.
        # Text is HTML-escaped because it comes from users and the model.
        bubbles = []
        for msg in reversed(st.session_state.historial):
            autor, texto = msg[0], msg[1]
            safe_autor = html.escape(str(autor))
            safe_texto = html.escape(str(texto)).replace("\n", "<br>")
            if autor == "Tú":
                bubbles.append(
                    '<div style="background:#DCF8C6;color:#000;padding:8px 12px;'
                    'border-radius:12px;margin:6px 0;text-align:right;">'
                    f"🧍‍♂️ <b>{safe_autor}:</b> {safe_texto}</div>"
                )
            else:
                bubbles.append(
                    '<div style="background:#F1F0F0;color:#000;padding:8px 12px;'
                    'border-radius:12px;margin:6px 0;text-align:left;">'
                    f"🤖 <b>{safe_autor}:</b> {safe_texto}</div>"
                )
        st.markdown(
            '<div style="max-height:450px;overflow-y:auto;padding:4px;">'
            + "".join(bubbles)
            + "</div>",
            unsafe_allow_html=True,
        )

#***************************************************************************
# END
#***************************************************************************