import os
import tempfile
import time
from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, pipeline

os.environ["HF_HOME"] = "/tmp/huggingface_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
os.environ["HF_HUB_CACHE"] = "/tmp/hf_hub_cache"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/sentence_transformers_cache"

for cache_dir in ["/tmp/huggingface_cache", "/tmp/transformers_cache", "/tmp/hf_hub_cache", "/tmp/sentence_transformers_cache"]:
    os.makedirs(cache_dir, exist_ok=True)
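# Routing all caches to /tmp is an assumption suited to hosted runtimes
# (e.g., containers or Hugging Face Spaces) where the home directory may be
# read-only; for local development these paths can point anywhere writable.
# TRANSFORMERS_CACHE is deprecated in newer transformers releases in favor
# of HF_HOME, but setting both is harmless.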
|
|
class StreamingCallbackHandler(BaseCallbackHandler):
    """Callback handler for streaming responses."""

    def __init__(self, placeholder):
        self.placeholder = placeholder
        self.text = ""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Handle new token from LLM."""
        self.text += token
        self.placeholder.markdown(self.text + "▌")

    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        """Handle end of LLM response."""
        self.placeholder.markdown(self.text)
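# Usage sketch (an assumption; the handler is not wired up in this section):
# it can be passed to a chain call so tokens render live in Streamlit, e.g.
#
#     handler = StreamingCallbackHandler(st.empty())
#     qa_chain.invoke({"query": question}, config={"callbacks": [handler]})
#
# Note that on_llm_new_token only fires if the underlying LLM actually emits
# token callbacks, which HuggingFacePipeline does not do by default.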
|
|
def load_documents(file_path: str):
    """Loads documents from a specified file path."""
    loader = TextLoader(file_path)
    return loader.load()
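# TextLoader reads the entire file into a single Document (a one-element
# list), so the meaningful chunking happens in split_documents below.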
|
|
def split_documents(documents, chunk_size=300, chunk_overlap=50):
    """Splits documents into smaller, more focused chunks."""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
    )
    return splitter.split_documents(documents)
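# With the defaults above, consecutive chunks share 50 characters of context,
# so each new chunk advances roughly 250 characters and a 1,000-character
# document yields about four chunks (the splitter prefers the listed
# separators, so sizes are approximate). Small chunks favor precise retrieval
# at the cost of context per chunk.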
|
|
def create_embeddings(model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Creates HuggingFace embeddings with proper cache handling."""
    try:
        embeddings = HuggingFaceEmbeddings(
            model_name=model_name,
            cache_folder="/tmp/sentence_transformers_cache"
        )
        return embeddings
    except Exception as e:
        print(f"Error creating embeddings with {model_name}: {e}")
        try:
            print("Trying fallback model: sentence-transformers/paraphrase-MiniLM-L6-v2")
            embeddings = HuggingFaceEmbeddings(
                model_name="sentence-transformers/paraphrase-MiniLM-L6-v2",
                cache_folder="/tmp/sentence_transformers_cache"
            )
            return embeddings
        except Exception as e2:
            print(f"Fallback model also failed: {e2}")
            raise e2
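# Both the primary and fallback encoders are 384-dimensional MiniLM models,
# so falling back does not change the vector dimensionality Chroma expects
# (the two models still embed text differently, so mixing them within one
# index would degrade retrieval).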
|
|
def setup_vector_store(docs, embeddings, persist_directory="./chroma_db"):
    """Sets up and persists the Chroma vector store."""
    db = Chroma.from_documents(docs, embeddings, persist_directory=persist_directory)
    return db.as_retriever(search_kwargs={"k": 3})
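# The retriever fetches the k=3 chunks most similar to the query embedding.
# With persist_directory set, Chroma writes the index to disk so it can be
# reloaded between runs (recent Chroma versions persist automatically; older
# ones required an explicit db.persist()).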
|
|
def create_qa_chain(retriever, model_name="Sakalti/Qwen2.5-1B-Instruct"):
    """Creates an enhanced QA chain with better prompting and streaming capabilities."""
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir="/tmp/transformers_cache",
            trust_remote_code=True
        )

        # Some causal LMs ship without a pad token; reuse EOS so padding works.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            cache_dir="/tmp/transformers_cache",
            device_map="auto",
            trust_remote_code=True,
            torch_dtype="auto"
        )

        # Low temperature plus a repetition penalty aims to keep answers
        # grounded and terse; return_full_text=False strips the prompt echo.
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=150,
            temperature=0.3,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.3,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_full_text=False
        )

        llm = HuggingFacePipeline(pipeline=pipe)

        prompt_template = """CONTEXT INFORMATION:
{context}

QUESTION: {question}

INSTRUCTIONS: Answer the question using ONLY the information provided in the context above. If the answer is not in the context, say "I don't have that information in the provided context."

ANSWER:"""

        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=["context", "question"]
        )

        # "stuff" concatenates all retrieved chunks into one prompt.
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            retriever=retriever,
            chain_type="stuff",
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt}
        )
        return qa_chain

    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        return None
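# Hypothetical end-to-end wiring of the helpers above (file name and question
# are illustrative only, not part of the original app):
#
#     retriever = setup_vector_store(
#         split_documents(load_documents("notes.txt")), create_embeddings()
#     )
#     qa = create_qa_chain(retriever)
#     if qa is not None:
#         print(qa.invoke({"query": "Who wrote the notes?"})["result"])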
|
|
def create_streaming_response(qa_chain, question: str, placeholder):
    """Create a streaming response using the QA chain."""
    try:
        result = qa_chain.invoke({"query": question})

        # Log the retrieved chunks to the console for debugging.
        if "source_documents" in result:
            print("=== RETRIEVED CONTEXT ===")
            for i, doc in enumerate(result["source_documents"]):
                print(f"Document {i+1}: {doc.page_content[:200]}...")
            print("=== END CONTEXT ===")

        answer = result.get("result", "")
        answer = clean_response(answer)

        # Fall back to quoting the retrieved context when the model produces
        # a very short or non-committal answer.
        if len(answer) < 20 or "I don't know" in answer or "cannot answer" in answer:
            if "source_documents" in result and result["source_documents"]:
                context_text = " ".join([doc.page_content for doc in result["source_documents"][:2]])
                answer = f"Based on the information I have: {context_text[:300]}..."

        # Reveal the answer a character at a time with a cursor glyph,
        # pausing briefly every third character for a typing effect.
        displayed_text = ""
        for i, char in enumerate(answer):
            displayed_text += char
            placeholder.markdown(displayed_text + "▌")
            if i % 3 == 0:
                time.sleep(0.02)

        placeholder.markdown(displayed_text)
        return displayed_text

    except Exception as e:
        placeholder.error(f"Error generating response: {e}")
        return "I apologize, but I encountered an error while processing your question."
|
|
def clean_response(text: str) -> str:
    """Clean up the response to remove repetition and improve quality."""
    if not text:
        return "I couldn't find relevant information to answer your question."

    # Keep only what follows the ANSWER: marker, if the model echoed the prompt.
    if "ANSWER:" in text:
        text = text.split("ANSWER:", 1)[-1].strip()

    # Strip prompt fragments and boilerplate lead-ins the model may echo.
    prompt_artifacts = [
        "CONTEXT INFORMATION:",
        "QUESTION:",
        "INSTRUCTIONS:",
        "Based on the context provided,",
        "According to the document,",
        "The document states that",
        "From the information given,",
    ]
    for artifact in prompt_artifacts:
        if artifact in text:
            text = text.split(artifact, 1)[-1].strip()

    # Drop near-duplicate sentences: a sentence counts as repetitive if more
    # than 70% of its words already appear in one of the previous two.
    sentences = text.split('.')
    cleaned_sentences = []
    for sentence in sentences:
        sentence = sentence.strip()
        if sentence and len(sentence) > 5:
            is_repetitive = False
            for recent in cleaned_sentences[-2:]:
                if len(set(sentence.split()) & set(recent.split())) > len(sentence.split()) * 0.7:
                    is_repetitive = True
                    break
            if not is_repetitive:
                cleaned_sentences.append(sentence)

    result = '. '.join(cleaned_sentences)
    if result and not result.endswith('.'):
        result += '.'

    # Truncate overly long answers at the last full sentence under 400 chars.
    if len(result) > 400:
        sentences = result[:400].split('.')
        result = '. '.join(sentences[:-1]) + '.'

    return result if result.strip() else "I couldn't generate a proper response. Please try rephrasing your question."
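# Quick illustration of the dedup heuristic (traced by hand, worth verifying):
#
#     clean_response("The sky is blue. The sky is blue. The sky is blue today.")
#     # -> "The sky is blue."
#
# Since the split is on '.', sentences ending in '!' or '?' are not subject
# to the overlap check.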