Mori_Bot / app.py
tecuhtli's picture
Update app.py
0cd97d1 verified
raw
history blame
13.1 kB
#***************************************************************************
#Importing Libraries
#***************************************************************************
import os, sys, torch, json, csv, warnings, joblib, uuid
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
import streamlit as st
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
from unidecode import unidecode
from datetime import datetime
from huggingface_hub import hf_hub_download, login
#***************************************************************************
#Defining default paths for the model to work
#***************************************************************************
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#***************************************************************************
#Setting up variables
#***************************************************************************
# Token privado desde variable de entorno
HF_TOKEN = os.environ.get("HF_TOKEN")
#***************************************************************************
#Functions
#***************************************************************************
# Function to clean the question field
def limpiar_input():
st.session_state["entrada"] = ""
# Function to save user interaction
def saving_interaction(question, response, context, user_id):
'''
inputs:
question --> User input question
response --> Mori response to the user question
context --> Context related to the user input, found by the trained classifier
user_id --> ID for the current user (Unique ID per session)
'''
# Getting the current time for the saving log
timestamp = datetime.now().isoformat()
# Defining the path to save the current interaction
stats_dir = Path("Statistics")
stats_dir.mkdir(parents=True, exist_ok=True)
# Setting the file to save the interactions
archivo_csv = stats_dir / "conversaciones_log.csv"
existe_csv = archivo_csv.exists()
# Saving statistics as a csv
with open(archivo_csv, mode="a", encoding="utf-8", newline="") as f_csv:
writer = csv.writer(f_csv)
if not existe_csv:
writer.writerow(["timestamp", "user_id", "contexto", "pregunta", "respuesta"])
writer.writerow([timestamp, user_id, context, question, response])
# Saving statiistics as a json file
archivo_jsonl = stats_dir / "conversaciones_log.jsonl"
with open(archivo_jsonl, mode="a", encoding="utf-8") as f_jsonl:
registro = {
"timestamp": timestamp,
"user_id": user_id,
"context": context,
"pregunta": question,
"respuesta": response}
f_jsonl.write(json.dumps(registro, ensure_ascii=False) + "\n")
# Function to load models within the huggingface respositories space
@st.cache_resource
def load_model(path_str):
'''
inputs:
path_str --> Path for loading models and tokenizers
outsputs:
model --> Loaded Model
tokenizer --> Loaded tokenizer
'''
path = Path(path_str).resolve()
tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True)
model = AutoModelForSeq2SeqLM.from_pretrained(path, local_files_only=True)
return model, tokenizer
# Funcion para clasificar las preguntas del usuario definiendo el contexto de las mismas
def classify_context(question, label_classes, model, tokenizer, device):
'''
inputs:
question --> Pregunta formulada por el usuario
label_classes --> Clases del label encoder para decodificar inferencias
model --> Clasificador para determinar el contexto de las pregutnas
tokenizer --> Tokenizer usada para clasificar contextos
device --> Usar el GPU o el CPU dependiendo de su disponibilidad
outsputs:
predicted_label --> Clasificacion de la pregunta en diversos contextos (clases)
'''
# Moviendo el modelo al device disponible
model = model.to(device)
# Procesando la entrada del usuario
inputs = tokenizer(question, return_tensors="pt", padding=True, truncation=True, max_length=128)
inputs = {key: val.to(device) for key, val in inputs.items()}
# Clasificacion de la pregunta del usuario en contextos
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
# Inferencia del clasificador
pred_intent = torch.argmax(logits, dim=1).item()
predicted_label = label_classes[pred_intent]
return predicted_label
# Funcion para generar respuestas tecnicas de Mori
def technical_asnwer(question, context, model, tokenizer, device):
'''
inputs:
question --> Pregunta formulada por el usuario
context --> Contexto de la preguntadel usario definido por el clasificador
model --> Modelo de Mori para responder preguntas tecnicas
tokenizer --> Tokenizer usado para procesar entradas y decoodificar respuestas
device --> Usar el GPU o el CPU dependiendo de su disponibilidad
outsputs:
response --> Respues de Mori tecnico (Modelo tecnico)
'''
# Moviendo el modelo al device disponible
model = model.to(device)
# Promp Engineering para ayudar a Mori a encontrar la mejor respuesta
input_text = f"Context: {context} [SEP] Question: {question}"
# Tokenizando el texto de entrada
inputs = tokenizer(input_text, return_tensors="pt").to(device)
# Generando la respuesta
summary_ids = model.generate(inputs['input_ids'], max_length=150, num_beams=5, early_stopping=True)
# Decodificando la respuesta
response = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
return "🧠 [Mori Técnico] " + response.strip()
# Funcion para generar respuestas sociales de Mori
def social_asnwer(question, model, tokenizer, device):
'''
inputs:
question --> Pregunta formulada por el usuario
model --> Modelo de Mori para responder preguntas sociales
tokenizer --> Tokenizer usado para procesar entradas y decoodificar respuestas
device --> Usar el GPU o el CPU dependiendo de su disponibilidad
outsputs:
response --> Respues de Mori social (Modelo social)
'''
# Moviendo el modelo al device disponible
model = model.to(device)
# Tokenizando la entrada del usuario sin agregar <eos> explícitamente
inputs = tokenizer(
question, # ✅ sin agregar eos_token
return_tensors="pt",
padding=True,
truncation=True,
max_length=128 # ✅ especificado para evitar warning
).to(device)
# Generando respuesta usando muestreo
output_ids = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"], # ✅ FIX agregado
max_length=50,
pad_token_id= tokenizer.eos_token_id,
do_sample=True,
top_p=0.95,
top_k=50)
# Decodificando y limpiando la respuesta
response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
return "🤝 [Mori Social] " + response.strip()
# Funcion para generar respuesta general de Mori
def contextual_asnwer(question, label_classes, context_model, cont_tok, tec_model, tec_tok, soc_model, soc_tok, device):
'''
inputs:
question --> Pregunta formulada por el usuario
label_classes --> Clases del label encoder para decodificar inferencias
context_model --> Clasificador para determinar el contexto de las pregutnas
cont_tok --> Tokenizer usada para clasificar contextos
tec_model --> Modelo de Mori para responder preguntas tecnicas
tec_tok --> Tokenizer usado por Mori Tenico
soc_model --> Modelo de Mori para responder preguntas sociales
soc_tok --> Tokenizer usado por Mori Social
device --> Usar el GPU o el CPU dependiendo de su disponibilidad
outsputs:
response --> Respues de Mori General (Respues con Prompt Engineering)
'''
# Detectar contexto usando el clasificador
context = classify_context(question, label_classes, context_model, cont_tok, device)
context_icons = {"social": "💬",
"modelos": "🔧",
"evaluación": "📏",
"optimización": "⚙️",
"visualización": "📈",
"aprendizaje": "🧠",
"vida digital" : "🧑‍💻",
"estadística": "📊",
"infraestructura": "🖥",
"datos": "📂",
"transformación digital": "🌀"}
icon = context_icons.get(context, "🧠")
#print(f"{icon} Contexto detectado: {context}") # (opcional para debug)
st.markdown(f"**{icon} Contexto detectado:** `{context}`")
if context == 'social':
# Generar respuesta contextual usando el modelo social
response = social_asnwer(question, soc_model,soc_tok, device)
else:
# Generar respuesta contextual usando el modelo tecnico
response = technical_asnwer(question, context, tec_model, tec_tok, device)
return response, context
#***************************************************************************
#MAIN
#***************************************************************************
if __name__ == '__main__':
# Setting historial for the current user
if "historial" not in st.session_state:
st.session_state.historial = []
# Addigning a new ID to the current user
if "user_id" not in st.session_state:
st.session_state["user_id"] = str(uuid.uuid4())[:8] # Ej: "f6a9b3e2"
# Loading classifier encoder classes:
labels_path = hf_hub_download(repo_id="tecuhtli/mori-context-model", filename="context_labels.pkl", use_auth_token=HF_TOKEN)
label_classes = joblib.load(labels_path)
# Loading Saved Models
# Modelo Contexto
context_model = AutoModelForSequenceClassification.from_pretrained("tecuhtli/mori-context-model", use_auth_token=HF_TOKEN)
cont_tok = AutoTokenizer.from_pretrained("tecuhtli/mori-context-model", use_auth_token=HF_TOKEN)
# Modelo Técnico
tec_tok = AutoTokenizer.from_pretrained("tecuhtli/mori-tecnico-model", use_auth_token=HF_TOKEN)
tec_model = AutoModelForSeq2SeqLM.from_pretrained("tecuhtli/mori-tecnico-model", use_auth_token=HF_TOKEN)
# Modelo Social
soc_tok = AutoTokenizer.from_pretrained("tecuhtli/mori-social-model", use_auth_token=HF_TOKEN)
soc_model = AutoModelForSeq2SeqLM.from_pretrained("tecuhtli/mori-social-model", use_auth_token=HF_TOKEN)
# Available Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Defining Moris Presetation
st.title("🤖 Mori - Tu Asistente Personal 🐈")
st.caption("💬 Puedes preguntarme conceptos técnicos como visualización, limpieza, BI, etc.")
st.caption("😅 Por el momento, solo puedo contestar preguntas como: ")
st.caption("🤓 ¿Como estas? ¿Que son?, Explícame algo, Define algo, ¿Para que sirven?")
st.caption("✏️ Escribe 'salir' para terminar.\n")
#entrada_usuario = st.text_area("📝 Escribe tu pregunta aquí")
with st.form("formulario_mori"):
user_question = st.text_area("📝 Escribe tu pregunta aquí", key="entrada", height=100)
submitted = st.form_submit_button("Responder")
if submitted:
if not user_question:
print("Mori: ¿Podrías repetir eso? No entendí bien 😅")
else:
#if st.button("Responder") and entrada_usuario:
response, context = contextual_asnwer(user_question, label_classes, context_model, cont_tok, tec_model, tec_tok, soc_model, soc_tok, device)
st.success(response)
# Guarda en historial
st.session_state.historial.append(("Mori", response))
st.session_state.historial.append(("Tú", user_question))
# 💾 Guarda en archivo para stats/dataset
saving_interaction(user_question, response, context, st.session_state["user_id"])
# 🔁 Muestra historial
if st.session_state.historial:
st.markdown("---")
for autor, texto in reversed(st.session_state.historial):
if autor == "Tú":
st.markdown(f"🧍‍♂️ **{autor}**: {texto}")
else:
st.markdown(f"🤖 **{autor}**: {texto}")
if st.session_state.historial:
texto_chat = ""
for autor, texto in st.session_state.historial:
texto_chat += f"{autor}: {texto}\n\n"
st.download_button(
label="💾 Descargar conversación como .txt",
data=texto_chat,
file_name="conversacion_mori.txt",
mime="text/plain")
#***************************************************************************
#FIN
#***************************************************************************