Spaces:
Paused
Paused
| import streamlit as st | |
| from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoConfig, AutoModelForSequenceClassification | |
| from langchain_community.llms import HuggingFacePipeline | |
| from langchain.prompts import PromptTemplate | |
| from langchain.chains import LLMChain | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from PyPDF2 import PdfReader | |
| from docx import Document | |
| import csv | |
| import json | |
| import torch | |
| from langchain_community.vectorstores import FAISS | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from huggingface_hub import login | |
# Authenticate against the Hugging Face Hub using the token kept in Streamlit secrets.
huggingface_token = st.secrets["HUGGINGFACE_TOKEN"]
login(token=huggingface_token)
# Configure the model, tokenizer and text-generation pipeline.
model_name = 'Qwen/Qwen2-1.5B'
# Kept as a module-level name for backward compatibility; nothing in the
# visible code reads it.
model_config = AutoConfig.from_pretrained(model_name)


@st.cache_resource
def _load_text_generation(name):
    """Build the tokenizer and the generation pipeline once per server process.

    Streamlit re-executes the script on every widget interaction; caching the
    loader avoids re-instantiating the model on each rerun.

    Args:
        name: Hugging Face model identifier to load.

    Returns:
        A ``(tokenizer, pipeline)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
    # Reuse the end-of-sequence token for padding, padding on the right.
    tok.pad_token = tok.eos_token
    tok.padding_side = "right"
    generator = pipeline(
        model=name,
        tokenizer=tok,
        task="text-generation",
        # Sampling must be enabled explicitly, otherwise `temperature` is
        # ignored and decoding is greedy.
        do_sample=True,
        temperature=0.2,
        repetition_penalty=1.1,
        return_full_text=True,
        max_new_tokens=1000,
    )
    return tok, generator


tokenizer, text_generation_pipeline = _load_text_generation(model_name)
# Prompt skeleton; {context} and {question} are filled in for every request.
prompt_template = """
### [INST]
Instruction: Answer the question based on your knowledge. Here is context to help:
{context}
### QUESTION:
{question}
[/INST]
"""

# Build the prompt object from the template above.
prompt = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question"],
)

# Expose the transformers pipeline through LangChain's LLM interface.
mistral_llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

# Chain that renders the prompt and sends it to the model.
llm_chain = LLMChain(prompt=prompt, llm=mistral_llm)
# Extract plain text from an uploaded file.
def handle_uploaded_file(uploaded_file):
    """Return the textual content of an uploaded file.

    The format is selected from the file name's extension, compared
    case-insensitively. Supported extensions: .txt, .pdf, .docx, .csv, .json.

    Args:
        uploaded_file: File-like object exposing ``name`` and ``read()``.

    Returns:
        The extracted text. For an unsupported extension the message
        "Tipo de archivo no soportado." is returned. If extraction raises,
        the exception message is returned as the text (best-effort behaviour
        kept for backward compatibility; callers receive a string either way).
    """
    try:
        # Lower-case once so "REPORT.PDF" is treated like "report.pdf".
        filename = uploaded_file.name.lower()
        if filename.endswith(".txt"):
            return uploaded_file.read().decode("utf-8")
        if filename.endswith(".pdf"):
            reader = PdfReader(uploaded_file)
            # Guard against a page yielding no text (None) so the join cannot fail.
            return "".join(page.extract_text() or "" for page in reader.pages)
        if filename.endswith(".docx"):
            doc = Document(uploaded_file)
            return "\n".join(para.text for para in doc.paragraphs)
        if filename.endswith(".csv"):
            lines = uploaded_file.read().decode("utf-8").splitlines()
            # Flatten all cells of all rows into one space-separated string.
            return " ".join(" ".join(row) for row in csv.reader(lines))
        if filename.endswith(".json"):
            data = json.load(uploaded_file)
            # Keep accented characters readable instead of \uXXXX escapes.
            return json.dumps(data, indent=4, ensure_ascii=False)
        return "Tipo de archivo no soportado."
    except Exception as e:
        return str(e)
# Translate text.
def translate(text, target_language):
    """Ask the LLM chain to translate ``text`` into ``target_language``.

    Args:
        text: Document text to translate.
        target_language: Name of the target language, inserted verbatim into
            the Spanish instruction sent to the model.

    Returns:
        Whatever ``llm_chain.run`` returns for the request.
    """
    # The instruction is sent with an empty context; the document travels in
    # the question itself.
    question = (
        f"Por favor, traduzca el siguiente documento al {target_language}:\n{text}\n"
        "Asegúrese de que la traducción sea precisa y conserve el significado original del documento."
    )
    return llm_chain.run(context="", question=question)
# Summarize text.
def summarize(text, length):
    """Ask the LLM chain for a summary of ``text``.

    Args:
        text: Document text to summarize.
        length: Spanish phrase describing the desired summary length; it is
            inserted verbatim after "resumen" in the instruction.

    Returns:
        Whatever ``llm_chain.run`` returns for the request.
    """
    # The instruction is sent with an empty context; the document travels in
    # the question itself.
    question = (
        f"Por favor, haga un resumen {length} del siguiente documento:\n{text}\n"
        "Asegúrese de que el resumen sea conciso y conserve el significado original del documento."
    )
    return llm_chain.run(context="", question=question)
# Classification model setup.
@st.cache_resource
def load_classification_model():
    """Load the document classifier and its tokenizer.

    The result is cached so the checkpoint is loaded once per server process
    instead of on every Streamlit rerun of the script.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint = "mrm8488/legal-longformer-base-8192-spanish"
    tokenizer_cls = AutoTokenizer.from_pretrained(checkpoint)
    model_cls = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return model_cls, tokenizer_cls


classification_model, classification_tokenizer = load_classification_model()

# Maps the classifier's output index to a document category.
# NOTE(review): confirm the checkpoint ships a classification head with these
# five classes in this order; nothing in this file establishes that.
id2label = {0: "multas", 1: "politicas_de_privacidad", 2: "contratos", 3: "denuncias", 4: "otros"}
def classify_text(text):
    """Return the category label predicted for ``text``.

    Args:
        text: Document text to classify; it is truncated to 4096 tokens.

    Returns:
        One of the values of ``id2label``. A predicted index that has no
        entry in ``id2label`` is reported as "otros".
    """
    # A single sequence needs no padding; truncation still caps the length.
    inputs = classification_tokenizer(
        text, return_tensors="pt", max_length=4096, truncation=True
    )
    classification_model.eval()
    with torch.no_grad():
        logits = classification_model(**inputs).logits
    predicted_class_id = logits.argmax(dim=-1).item()
    # Fall back to the catch-all category instead of raising KeyError when the
    # model emits an index outside the mapping.
    return id2label.get(predicted_class_id, "otros")
# Load the question/answer documents stored for a category.
def load_json_documents(category):
    """Read ``./<category>.json`` and return its Q&A pairs as strings.

    Args:
        category: Base name (without extension) of the JSON file to read from
            the current working directory.

    Returns:
        A list with one "<question> <answer>" string per entry under the
        "questions_and_answers" key, or an empty list when the file does not
        exist.
    """
    path = f"./{category}.json"
    try:
        handle = open(path, "r", encoding="utf-8")
    except FileNotFoundError:
        return []
    with handle:
        entries = json.load(handle)["questions_and_answers"]
    return [" ".join((item["question"], item["answer"])) for item in entries]
# FAISS and embeddings setup.
def create_vector_store(docs):
    """Build a FAISS vector store over ``docs``.

    Args:
        docs: Either a single string or an iterable of strings. Each string
            is split into chunks of up to 1000 characters with 150 characters
            of overlap before being embedded.

    Returns:
        The FAISS vector store built from the chunks.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-l6-v2",
        model_kwargs={"device": "cpu"},
    )
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    # split_text() works on one string, while the caller in this file passes a
    # list of documents: split each document and flatten the chunks.
    texts = [docs] if isinstance(docs, str) else docs
    split_docs = [chunk for text in texts for chunk in text_splitter.split_text(text)]
    vector_store = FAISS.from_texts(split_docs, embeddings)
    return vector_store
def explain_text(user_input, document_text):
    """Answer ``user_input`` using context retrieved for the document's category.

    The document is classified; for the four categories that have a knowledge
    base, the matching JSON documents are indexed and the passages most
    similar to the question become the prompt context. Otherwise the context
    is empty.

    Args:
        user_input: The user's question; sent to the model as the question.
        document_text: Text of the uploaded document; used only to pick the
            category.

    Returns:
        Whatever ``llm_chain.run`` returns for the request.
    """
    retrievable = ("multas", "politicas_de_privacidad", "contratos", "denuncias")
    label = classify_text(document_text)

    # Default to no context; it is only filled when a knowledge base exists.
    context = ""
    if label in retrievable:
        knowledge = load_json_documents(label)
        if knowledge:
            store = create_vector_store(knowledge)
            hits = store.similarity_search(user_input)
            context = " ".join(hit.page_content for hit in hits)

    return llm_chain.run(context=context, question=user_input)
def main():
    """Render the Streamlit UI and run the selected operation.

    The user picks one of three operations (summarize, translate, explain),
    uploads a file, and the model's answer is written to the page. "Explicar"
    additionally requires a question.
    """
    st.title("LexAIcon")
    # Name the model that is actually configured instead of a hard-coded one.
    st.write(f"Puedes conversar con este chatbot basado en {model_name} y subir archivos para que el chatbot los procese.")

    # NOTE(review): the original indentation is not recoverable here; it is
    # assumed that only the caption lives in the sidebar — confirm.
    with st.sidebar:
        st.caption("[Consigue un HuggingFace Token](https://huggingface.co/settings/tokens)")

    operation = st.radio("Selecciona una operación", ["Resumir", "Traducir", "Explicar"])

    # The question box is only shown for "Explicar" and precedes the uploader.
    user_input = st.text_area("Introduce tu pregunta:", "") if operation == "Explicar" else None
    uploaded_file = st.file_uploader("Sube un archivo", type=["txt", "pdf", "docx", "csv", "json"])
    if not uploaded_file:
        return

    if operation == "Explicar":
        if not user_input:
            return
        document_text = handle_uploaded_file(uploaded_file)
        bot_response = explain_text(user_input, document_text)
    elif operation == "Traducir":
        document_text = handle_uploaded_file(uploaded_file)
        target_language = st.selectbox("Selecciona el idioma de traducción", ["español", "inglés", "francés", "alemán"])
        if not target_language:
            return
        bot_response = translate(document_text, target_language)
    elif operation == "Resumir":
        document_text = handle_uploaded_file(uploaded_file)
        # Option shown in the UI -> length phrase inserted into the prompt.
        lengths = {
            "corto": "de aproximadamente 50 palabras",
            "medio": "de aproximadamente 100 palabras",
            "largo": "de aproximadamente 500 palabras",
        }
        summary_length = st.selectbox("Selecciona la longitud del resumen", list(lengths))
        if not summary_length:
            return
        bot_response = summarize(document_text, lengths[summary_length])
    else:
        return

    st.write(f"**Assistant:** {bot_response}")


if __name__ == "__main__":
    main()