# Hugging Face Spaces status: Runtime error
| import os | |
| import torch | |
| import logging | |
| logging.basicConfig(level=logging.DEBUG) | |
| logger = logging.getLogger(__name__) | |
| from langchain.chains import RetrievalQA | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_community.llms import HuggingFaceHub | |
# Run the embedding model on the first CUDA device when torch can see one,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

# Module-level state shared by the functions below; everything starts empty
# and is filled in by init_llm() / process_document() / process_prompt().
conversation_retrieval_chain = None  # RetrievalQA chain built by process_document()
chat_history = list()  # (prompt, answer) tuples appended by process_prompt()
llm_hub, embeddings = None, None  # set by init_llm()
def init_llm():
    """Initialise the module-wide LLM client and embedding model.

    Side effects:
        * Ensures ``HUGGINGFACEHUB_API_TOKEN`` is present in ``os.environ``.
          A non-empty value already in the environment is kept as-is. Only
          when the variable is missing or empty is the ``"YOUR_HF_TOKEN"``
          placeholder written, and a warning is logged because the
          placeholder is not a real token.
        * Sets the global ``llm_hub`` to a ``HuggingFaceHub`` client for
          ``mistralai/Mistral-7B-Instruct-v0.3``. It uses the
          ``text-generation`` task with ``temperature=0.1`` and
          ``max_new_tokens=256``.
        * Sets the global ``embeddings`` to a ``HuggingFaceEmbeddings`` model
          (``sentence-transformers/all-MiniLM-L6-v2``) placed on ``DEVICE``.
    """
    global llm_hub, embeddings
    logger.info("Initializing HuggingFace LLM and embeddings...")

    # Respect a token supplied by the environment (e.g. a deployment secret).
    # Unconditionally assigning the placeholder would discard it and make
    # every Hub request fail to authenticate.
    if not os.environ.get("HUGGINGFACEHUB_API_TOKEN"):
        logger.warning(
            "HUGGINGFACEHUB_API_TOKEN is not set; falling back to the "
            "placeholder token, which is not a valid credential"
        )
        os.environ["HUGGINGFACEHUB_API_TOKEN"] = "YOUR_HF_TOKEN"

    # Model from HuggingFace
    model_id = "mistralai/Mistral-7B-Instruct-v0.3"
    llm_hub = HuggingFaceHub(
        repo_id=model_id,
        task="text-generation",
        model_kwargs={
            "temperature": 0.1,
            "max_new_tokens": 256,
        },
    )
    logger.debug("HuggingFace LLM initialized")

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": DEVICE},
    )
    logger.debug("Embeddings initialized with device %s", DEVICE)
def process_document(document_path):
    """Load a PDF, index its chunks in Chroma, and build the RetrievalQA chain.

    The PDF is loaded with ``PyPDFLoader``. It is then split into chunks of up
    to 1024 characters with a 64-character overlap. The chunks are embedded
    with the module-level ``embeddings`` into a fresh Chroma collection, and
    the global ``conversation_retrieval_chain`` is replaced with a "stuff"
    ``RetrievalQA`` chain. That chain uses ``llm_hub`` over an MMR retriever
    (``k=6``, ``lambda_mult=0.25``).

    Args:
        document_path: Filesystem path of the PDF to load.

    Raises:
        ValueError: If no text chunks could be extracted from the document.
    """
    global conversation_retrieval_chain
    import uuid  # only needed to give each document its own collection name

    logger.info("Loading document from path: %s", document_path)
    loader = PyPDFLoader(document_path)
    documents = loader.load()
    logger.debug("Loaded %d documents", len(documents))

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1024,
        chunk_overlap=64,
    )
    texts = text_splitter.split_documents(documents)
    logger.debug("Split into %d chunks", len(texts))

    if not texts:
        # E.g. a PDF whose pages contain no extractable text. There is nothing
        # to embed, so fail clearly here instead of building a retriever that
        # can never supply any context.
        raise ValueError(f"No text could be extracted from {document_path!r}")

    # Without an explicit name, Chroma.from_documents uses a single fixed
    # default collection. With an in-process client that collection can
    # outlive this call, so chunks from an earlier PDF could be retrieved when
    # answering questions about this one. A unique name per call keeps each
    # document's index separate.
    db = Chroma.from_documents(
        texts,
        embedding=embeddings,
        collection_name=f"document-{uuid.uuid4().hex}",
    )

    conversation_retrieval_chain = RetrievalQA.from_chain_type(
        llm=llm_hub,
        chain_type="stuff",
        retriever=db.as_retriever(
            search_type="mmr",
            search_kwargs={"k": 6, "lambda_mult": 0.25},
        ),
        return_source_documents=False,
        input_key="question",
    )
    logger.info("RetrievalQA chain created")
def process_prompt(prompt):
    """Answer *prompt* using the chain built by ``process_document()``.

    The chain is invoked with ``{"question": prompt}``, and its ``"result"``
    field is returned. The ``(prompt, answer)`` pair is appended to the
    module-level ``chat_history`` list. Note that the history is only recorded
    here; it is not passed back into the chain.

    Args:
        prompt: The user's question.

    Returns:
        The answer text produced by the chain.

    Raises:
        RuntimeError: If no document has been processed yet, so
            ``conversation_retrieval_chain`` is still ``None``.
    """
    # No ``global`` declarations are needed: the chain is only read, and
    # chat_history is mutated in place rather than rebound.
    logger.info("Processing prompt: %s", prompt)

    chain = conversation_retrieval_chain
    if chain is None:
        # Without this guard the call fails with an opaque
        # "'NoneType' object has no attribute 'invoke'".
        raise RuntimeError(
            "No document has been processed yet; call process_document() first."
        )

    output = chain.invoke({"question": prompt})
    answer = output["result"]

    chat_history.append((prompt, answer))
    logger.debug("Chat history length: %d", len(chat_history))
    return answer
# Initialise the LLM client and embeddings as an import-time side effect.
# This populates the module globals llm_hub / embeddings before
# process_document() / process_prompt() are called by whoever imports this
# module.
init_llm()
# Reached only if init_llm() returned without raising.
logger.info("LLM initialization complete")