|
|
import streamlit as st |
|
|
from streamlit_chat import message |
|
|
import pandas as pd |
|
|
import os |
|
|
import json |
|
|
from transformers import pipeline |
|
|
from dotenv import load_dotenv |
|
|
from utils.cno_utils import convert_to_cno, get_cno_description |
|
|
|
|
|
# Avatar image shown next to the bot's chat bubbles.
AVATAR_PATH = "https://avatars.githubusercontent.com/u/122880210?s=200&v=4"

# Page title and icon are passed positionally (page_title, page_icon).
st.set_page_config(
    "Clasificador CNO 🤖", "🤖", layout="wide", initial_sidebar_state="expanded"
)
|
|
|
|
|
# Global CSS injected once: Inter font, hidden default Streamlit header/footer,
# chat-bubble styling, slim scrollbar, sidebar section titles, and a hidden
# sidebar collapse button.
st.markdown(
    """<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600&display=swap" rel="stylesheet">
<style>
html, body, [class*="css"] {font-family: "Inter", sans-serif;}
header, footer {visibility: hidden;}
.block-container {padding-top: 0.5rem; display: flex; flex-direction: column; min-height: 100vh;}
.stChatMessage.user {background: linear-gradient(120deg,#00c6ff 0%,#0072ff 100%); color:#fff; border-radius:1rem; padding:0.75rem 1rem; margin:0.25rem 0;}
.stChatMessage.bot {background:#fff; border:1px solid #e0e0e0; border-radius:1rem; padding:0.75rem 1rem; margin:0.25rem 0;}
.stChatMessage.bot img {height:24px;width:24px;border-radius:50%;margin-right:0.5rem;}
::-webkit-scrollbar {width:8px;}
::-webkit-scrollbar-thumb {background:#8f9dff;border-radius:10px;}
.sidebar-title {font-size:0.9rem;font-weight:600;margin:0.5rem 0 0.25rem;color:#4b4b4b;}
div.msg ul{list-style:none;padding-left:0;margin:0;}
[data-testid="stSidebarCollapseButton"] {
    display: none;
}
</style>""",
    unsafe_allow_html=True,
)
|
|
# Hugging Face Hub repo id of the text-classification model loaded below.
MODEL_ID = "bob-nlp/A5-CNO-ULL-BOB-ISTAC-D12"
|
|
|
|
|
|
|
|
@st.cache_resource
def load_huggingface_model():
    """Load the text-classification pipeline from the Hugging Face Hub.

    Returns the pipeline on success, or None when the HF_TOKEN is missing
    or model loading fails (an error is shown in the UI in both cases).
    Cached by Streamlit so the model is downloaded only once per process.
    """
    load_dotenv()  # allow HF_TOKEN to come from a local .env file
    token = os.environ.get("HF_TOKEN")

    if not token:
        st.error(
            "HF_TOKEN no encontrado. Por favor, configúralo en los 'Secrets' de tu Space.",
            icon="🔑",
        )
        return None

    try:
        return pipeline("text-classification", model=MODEL_ID, token=token)
    except Exception as e:
        st.error(f"Error al cargar el modelo '{MODEL_ID}': {e}", icon="🔥")
        return None
|
|
|
|
|
|
|
|
def load_json_file(filename):
    """Load and parse a JSON file.

    Args:
        filename: Path of the JSON file to read (UTF-8 encoded).

    Returns:
        The parsed JSON content, or an empty dict when the file is missing
        or unreadable (a Streamlit warning/error is displayed instead of
        raising).
    """
    try:
        with open(filename, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        # Bug fix: the message previously printed the literal text
        # "(unknown)" because the f-string never interpolated `filename`.
        st.warning(
            f"El archivo '{filename}' no se ha encontrado en el repositorio del Space."
        )
        return {}
    except Exception as e:
        # Same fix here: report which file actually failed to parse.
        st.error(f"Error al leer el archivo JSON '{filename}': {e}")
        return {}
|
|
|
|
|
|
|
|
# Load the model and lookup data once at import time.
pipe = load_huggingface_model()
# Sidebar filter metadata: maps a column name to its selectable options.
METADATA = load_json_file("src/data/metadata.json")
# CNO codes flagged as unreliable; predictions matching them get a warning.
PROBLEMATIC_CNOS = load_json_file("src/data/problematic_cnos.json")
|
|
|
|
|
|
|
|
def run_inference(text_input):
    """Classify *text_input* with the Hugging Face pipeline and return the
    top-3 predictions formatted as a single markdown string for the chat UI.

    Returns an error string when the model is not loaded or inference fails.
    """
    if not pipe:
        return "Error: El modelo no está cargado."

    unreliable_suffix = " ⚠️⚠️⚠️ **Cuidado: código poco fiable** "

    try:
        predictions = pipe(text_input, top_k=3)
        lines = []
        for pred in predictions:
            # Translate the raw model label into its CNO code and attach the
            # human-readable description.
            pred["label"] = convert_to_cno(pred["label"])
            pred["description"] = get_cno_description(pred["label"])
            line = (
                f"Predicción: **{pred['label']}**: {pred['description']} "
                f"Certeza: **{pred['score']:.2f}** "
            )
            # Tag predictions for codes known to be unreliable.
            if pred["label"] in PROBLEMATIC_CNOS:
                line += unreliable_suffix
            lines.append(line)
        return "\n".join(lines)
    except Exception as e:
        return f"Ocurrió un error durante la inferencia: {e}"
|
|
|
|
|
|
|
|
# Sidebar header and the "Filtros" section label (styled via .sidebar-title CSS).
st.sidebar.title("Clasificador CNO-11")
st.sidebar.markdown("---")
st.sidebar.markdown(
    "<div class='sidebar-title'>🎛️ Filtros</div>", unsafe_allow_html=True
)
|
|
|
|
|
|
|
|
def init_state() -> None:
    """Seed st.session_state with defaults from the defaults JSON file.

    Uses setdefault so keys already set in the current session are untouched.
    """
    for key, value in load_json_file("src/data/defaults_session_state.json").items():
        st.session_state.setdefault(key, value)
|
|
|
|
|
|
|
|
# Populate session defaults before the widgets below read session_state.
init_state()
|
|
|
|
|
|
|
|
def on_controls_change() -> None:
    """Reset the conversation history when a sidebar filter changes."""
    for history in (st.session_state.past, st.session_state.generated):
        history.clear()
|
|
|
|
|
|
|
|
# One sidebar selectbox per metadata column; changing any of them clears the
# chat (on_controls_change) so old answers don't reflect stale filters.
for col, metadatas in METADATA.items():
    sel = st.sidebar.selectbox(
        label=col,
        options=list(metadatas),
        key=f"select_{col}",
        on_change=on_controls_change,
        # Options are dicts; display their human-readable 'textual' field.
        format_func=lambda x: x["textual"],
    )
    # Remember the chosen option so add_user_message can append it to the text.
    st.session_state.selections[col] = sel

st.sidebar.divider()
# Manual reset: wipe the history and rerun so the chat area re-renders empty.
if st.sidebar.button("🗑️ Limpiar conversación", use_container_width=True):
    on_controls_change()
    st.rerun()
|
|
|
|
|
# Page heading plus a link to the model card of the model actually in use.
st.title("🤖 Clasificador CNO-11 ULL BOB-ISTAC")
st.info(
    # Bug fix: the URL previously pointed at a different repo id
    # ("A5-CNO-BOB-ISTAC-D12") than MODEL_ID; build it from MODEL_ID so the
    # link always matches the loaded model.
    f"Utilizando el modelo: **[{MODEL_ID}](https://huggingface.co/{MODEL_ID})**"
)
|
|
|
|
|
|
|
|
def add_user_message(text: str):
    """Store the user's message, enrich it with sidebar selections, run
    inference, and append the model's reply to the conversation history.
    """
    st.session_state.past.append(text)

    # Build the classification input: the message as a sentence, followed by
    # one sentence per active sidebar selection.
    parts = [f"{text}."]
    for selection in st.session_state.selections.values():
        if selection:
            parts.append(f" {selection['textual']}.")

    st.session_state.generated.append(run_inference("".join(parts)))
|
|
|
|
|
|
|
|
def render_chat():
    """Render the welcome bubble followed by the whole conversation so far."""
    message(
        "¡Hola! Soy el clasificador de códigos CNO-11. Por favor, introduce una descripción de la tarea o ocupación que quieres clasificar y te ayudaré a encontrar el código CNO correspondiente.",
        is_user=False,
        key="welcome",
        logo=AVATAR_PATH,
    )

    # Paired user/bot turns; zip stops at the shorter list if one side lags.
    turns = zip(st.session_state.past, st.session_state.generated)
    for idx, (user_text, bot_text) in enumerate(turns):
        message(user_text, is_user=True, key=f"u{idx}", avatar_style="no-avatar")
        message(bot_text, key=f"b{idx}", logo=AVATAR_PATH)
|
|
|
|
|
|
|
|
# Main chat area: render the history, then handle new input from the chat box.
chat_box = st.container()
with chat_box:
    st.markdown('<div class="chat-container">', unsafe_allow_html=True)
    render_chat()
    new_text = st.chat_input("Escribe aquí el texto a clasificar…")
    if new_text:
        # Classify the new message, then rerun so render_chat displays it.
        add_user_message(new_text)
        st.rerun()
    st.markdown("</div>", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
st.markdown("---") |
|
|
st.markdown( |
|
|
""" |
|
|
<div style="text-align: center; color: #666; font-size: 0.8rem; padding: 1rem 0;"> |
|
|
Desarrollado por <strong>Cátedra Cajasiete de Big Data, Open Data y Blockchain</strong><br> |
|
|
Universidad de La Laguna |
|
|
</div> |
|
|
""", |
|
|
unsafe_allow_html=True, |
|
|
) |
|
|
|