BeRU Deployer
Deploy BeRU Streamlit RAG System - Add app, models logic, configs, and optimizations for HF Spaces
dec533d | import os | |
| import torch | |
| import logging | |
| import asyncio | |
| import re | |
| from pathlib import Path | |
| from typing import List, Dict, Optional, Any | |
| from contextlib import asynccontextmanager | |
| from logging.handlers import RotatingFileHandler | |
| # --- LANGCHAIN IMPORTS --- | |
| from langchain_community.vectorstores import FAISS | |
| from langchain.chains import create_history_aware_retriever | |
| from langchain.chains.retrieval import create_retrieval_chain | |
| from langchain.chains.combine_documents import create_stuff_documents_chain | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain_community.llms import HuggingFacePipeline | |
| from langchain_core.embeddings import Embeddings | |
| from langchain_core.messages import HumanMessage, AIMessage | |
| from langchain_community.retrievers import BM25Retriever | |
| from langchain.retrievers import EnsembleRetriever | |
| from langchain.retrievers.multi_query import MultiQueryRetriever | |
| from langchain_core.runnables import RunnablePassthrough | |
| from langchain_core.output_parsers import StrOutputParser | |
| from operator import itemgetter | |
| # --- RERANKING IMPORTS --- | |
| # Ensure you have installed flashrank: pip install flashrank | |
| from langchain.retrievers import ContextualCompressionRetriever | |
| from langchain_community.document_compressors import FlashrankRerank | |
| # --- TRANSFORMERS IMPORTS --- | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForCausalLM, | |
| AutoModel, | |
| pipeline, | |
| BitsAndBytesConfig | |
| ) | |
| # --- FASTAPI IMPORTS --- | |
| from fastapi import FastAPI | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field, field_validator | |
| import uvicorn | |
| import numpy as np | |
| # ------------------------------------------------------------------------- | |
| # 1. Pydantic Patch (Crucial for offline serialization) | |
| # ------------------------------------------------------------------------- | |
def patch_pydantic_for_pickle():
    """Make pydantic-v1 models tolerant of pickles with an incomplete state.

    Wraps ``BaseModel.__setstate__`` of the ``pydantic.v1`` layer so that a
    pickled state missing ``__fields_set__`` or
    ``__private_attribute_values__`` has those keys filled in before it is
    restored. If the original ``__setstate__`` still raises, the three state
    attributes are written straight onto the instance instead.

    The patch is best-effort: if ``pydantic.v1`` cannot be imported or
    patched, a warning is printed and the function returns normally.
    """

    def _install_setstate_patch(model_cls):
        """Replace ``model_cls.__setstate__`` with the tolerant wrapper."""
        original_setstate = model_cls.__setstate__

        def patched_setstate(self, state):
            # Fill in the keys an older/newer pickle may not carry.
            if '__fields_set__' not in state:
                state['__fields_set__'] = set(state.get('__dict__', {}).keys())
            if '__private_attribute_values__' not in state:
                state['__private_attribute_values__'] = {}
            try:
                original_setstate(self, state)
            except Exception:
                # Last resort: object.__setattr__ bypasses the model's own
                # __setattr__, so the raw state is stored without validation.
                object.__setattr__(self, '__dict__', state.get('__dict__', {}))
                object.__setattr__(self, '__fields_set__', state.get('__fields_set__', set()))
                object.__setattr__(
                    self,
                    '__private_attribute_values__',
                    state.get('__private_attribute_values__', {}),
                )

        model_cls.__setstate__ = patched_setstate

    try:
        from pydantic.v1.main import BaseModel as PydanticV1BaseModel
        _install_setstate_patch(PydanticV1BaseModel)
        print("β Pydantic v1 patched for pickle compatibility")
    except ImportError:
        try:
            from pydantic.v1 import BaseModel as V1BaseModel
            _install_setstate_patch(V1BaseModel)
            print("β Pydantic patched for pickle compatibility")
        except Exception as e:
            print(f"β οΈ Could not patch Pydantic: {e}")


# Must run before any pickled LangChain documents are loaded.
patch_pydantic_for_pickle()
| # ------------------------------------------------------------------------- | |
| # 2. Configuration & Paths (workspace-agnostic) | |
| # ------------------------------------------------------------------------- | |
# Hugging Face libraries are forced offline; only local files are ever used.
os.environ.update(
    {
        "HF_HUB_DISABLE_SYMLINKS_WARNING": "1",
        "TRANSFORMERS_OFFLINE": "1",
        "HF_DATASETS_OFFLINE": "1",
        "HF_HUB_OFFLINE": "1",
    }
)


def _path_from_env(variable, default):
    """Return ``Path`` of env var *variable*, or of *default* when it is unset."""
    return Path(os.environ.get(variable, default))


# Base directory for application files inside a container (override: APP_ROOT).
ROOT_DIR = _path_from_env("APP_ROOT", "/app").resolve()

# Model and index locations; each can be overridden by the env var of the same name.
MODEL_DIR = _path_from_env("MODEL_DIR", ROOT_DIR / "models")
LLM_MODEL_PATH = _path_from_env("LLM_MODEL_PATH", MODEL_DIR / "Mistral-7B-Instruct-v0.3")
EMBED_MODEL_PATH = _path_from_env("EMBED_MODEL_PATH", MODEL_DIR / "VLM2Vec-Qwen2VL-2B")
FAISS_INDEX_PATH = _path_from_env("FAISS_INDEX_PATH", ROOT_DIR / "VLM2Vec-V2rag3")

# Seconds a single generation may take (raised to leave room for reranking).
GENERATION_TIMEOUT = 240

# String forms of the model paths, used in API payloads.
LLM_MODEL = str(LLM_MODEL_PATH)
EMBED_MODEL = str(EMBED_MODEL_PATH)

# Logging: rotating file "rag.log", 10 MiB per file, 5 backups kept.
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
handler = RotatingFileHandler("rag.log", maxBytes=10 * 1024 * 1024, backupCount=5)
handler.setFormatter(formatter)
logger = logging.getLogger("rag_system")
logger.addHandler(handler)
logger.setLevel(logging.INFO)

# Process-wide state, populated by load_system() and the request handlers.
vectorstore = None
llm_pipeline = None
qa_chain = None
answer_cache: Dict[str, Dict] = {}
conversations: Dict[str, List[Dict]] = {}
| # ------------------------------------------------------------------------- | |
| # 3. VLM2Vec Embedding Class (Preserved) | |
| # ------------------------------------------------------------------------- | |
class VLM2VecEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter around a locally stored VLM2Vec model.

    Each text is embedded by mean-pooling the model's last hidden layer over
    the attention mask. Model and tokenizer are loaded with
    ``local_files_only=True``, so nothing is downloaded.
    """

    # Compiled once; both patterns are applied to every embedded text.
    _WHITESPACE_RE = re.compile(r'\s+')
    _PAGE_MARKER_RE = re.compile(r'Page \d+', flags=re.IGNORECASE)

    def __init__(self, model_path: str, device: str = "cpu"):
        """Load tokenizer and model from *model_path* and probe the embedding size.

        Args:
            model_path: Directory of the local checkpoint.
            device: ``"cuda"`` selects fp16 with ``device_map="auto"``; any
                other value loads the model in fp32 on the CPU.
        """
        print(f"π Loading VLM2Vec model from: {model_path}")
        self.device = device
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            local_files_only=True,
        )
        # Padding needs a pad token; reuse EOS when the tokenizer defines none.
        if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        device_map = "auto" if device == "cuda" else "cpu"
        dtype = torch.float16 if device == "cuda" else torch.float32
        self.model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            dtype=dtype,
            device_map=device_map,
            local_files_only=True,
        )
        self.model.eval()
        try:
            # Inputs must be moved to wherever the first parameter lives.
            self.model_device = next(self.model.parameters()).device
        except (StopIteration, AttributeError):
            # No parameters to inspect: fall back to the requested device.
            self.model_device = torch.device("cuda" if device == "cuda" else "cpu")
        # One dummy forward pass reveals the hidden size of the last layer.
        with torch.no_grad():
            test_input = self.tokenizer("test", return_tensors="pt", add_special_tokens=True)
            test_input = {k: v.to(self.model_device) for k, v in test_input.items()}
            out = self.model(**test_input, output_hidden_states=True)
            self.embedding_dim = out.hidden_states[-1].shape[-1]
        print(f"β VLM2Vec loaded on {self.model_device} | dim={self.embedding_dim}\n")

    def _normalize_text(self, text: str) -> str:
        """Collapse whitespace runs, drop ``Page <n>`` markers, and strip the ends."""
        text = self._WHITESPACE_RE.sub(' ', text or "")
        text = self._PAGE_MARKER_RE.sub('', text)
        return text.strip()

    def _ensure_non_empty(self, text: str) -> str:
        """Return the normalized text, or ``"[EMPTY]"`` when nothing is left."""
        t = self._normalize_text(text)
        return t if t else "[EMPTY]"

    def _embed_single(self, text: str) -> List[float]:
        """Embed one text; on any failure log it and return an all-zero vector."""
        try:
            with torch.no_grad():
                clean_text = self._ensure_non_empty(text)
                inputs = self.tokenizer(
                    clean_text,
                    return_tensors="pt",
                    add_special_tokens=True,
                    padding=True,
                    truncation=True,
                    max_length=512,
                )
                inputs = {k: v.to(self.model_device) for k, v in inputs.items()}
                outputs = self.model(**inputs, output_hidden_states=True)
                if hasattr(outputs, "hidden_states") and outputs.hidden_states is not None:
                    # Masked mean over the sequence axis of the last hidden layer.
                    hidden_states = outputs.hidden_states[-1]
                    attention_mask = inputs["attention_mask"].unsqueeze(-1).float()
                    weighted = hidden_states * attention_mask
                    sum_embeddings = weighted.sum(dim=1)
                    # The clamp avoids a division by zero for an all-zero mask.
                    sum_mask = torch.clamp(attention_mask.sum(dim=1), min=1e-9)
                    embedding = (sum_embeddings / sum_mask).squeeze(0)
                else:
                    # No hidden states returned: average the logits instead.
                    embedding = outputs.logits.mean(dim=1).squeeze(0)
                return embedding.cpu().numpy().tolist()
        except Exception as e:
            logger.error(f"VLM2Vec embedding error: {e}")
            # Keep the vector length stable so callers never see a ragged batch.
            return [0.0] * getattr(self, "embedding_dim", 1024)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed every text in *texts*, one at a time, preserving order."""
        return [self._embed_single(t) for t in texts]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string."""
        return self._embed_single(text)
| # ------------------------------------------------------------------------- | |
| # 4. Prompt Templates (CLEANER & STRICTER) | |
| # ------------------------------------------------------------------------- | |
# Every template shares the same context/question tail; only the instruction differs.
_PROMPT_TAIL = "\nContext:\n{context}\nQuestion:\n{input} [/INST]"

_PROMPT_INSTRUCTIONS = (
    (
        "Short and Concise",
        "Answer the question based ONLY on the following context. "
        "Keep the answer under 3 sentences.",
    ),
    (
        "Detailed",
        "You are a helpful assistant. Answer the question using ONLY the following context. "
        "Provide a detailed summary (4-5 sentences).",
    ),
    (
        "Step-by-Step",
        "Based on the context below, provide a step-by-step procedure to answer the question.",
    ),
)

# Maps an answer-style name to its prompt; the keys are also the set of
# accepted "mode" / "answer_style" values.
PROMPT_TEMPLATES = {
    style_name: "<s>[INST] " + instruction + _PROMPT_TAIL
    for style_name, instruction in _PROMPT_INSTRUCTIONS
}
def structure_answer(answer: str, style: str) -> str:
    """Clean raw LLM output and apply the length rule of the answer style.

    Steps:
      1. Delete known prompt/role artifacts wherever they occur.
      2. Cut the text at the first marker that shows the model started a new
         turn (``Human:``, ``Question:``, ...).
      3. For ``"Short and Concise"`` keep at most the first two sentences.

    Args:
        answer: Raw generated text; ``None`` is treated as empty.
        style: Answer-style name; only ``"Short and Concise"`` changes the result.

    Returns:
        The cleaned answer, or an empty string when nothing remains.
    """
    answer = answer or ""

    # 1. Remove specific artifacts (the phrase is deleted, the rest is kept).
    artifacts = (
        "Enough thinking",
        "Note:",
        "System:",
        "User:",
        "[/INST]",
        "Here is the answer:",
        "Answer:",
    )
    for artifact in artifacts:
        answer = answer.replace(artifact, "")

    # 2. Stop at likely hallucination points: everything after a marker is dropped.
    stop_markers = ("Human:", "Question:", "User input:", "Context:")
    for marker in stop_markers:
        answer = answer.split(marker, 1)[0]
    clean_answer = answer.strip()

    # 3. Final formatting.
    if style == "Short and Concise" and clean_answer:
        # Split only after sentence-ending punctuation followed by whitespace,
        # so decimals such as "3.5" stay intact and no empty pieces appear.
        sentences = [s for s in re.split(r"(?<=[.!?])\s+", clean_answer) if s]
        clean_answer = " ".join(sentences[:2])
        if clean_answer[-1] not in ".!?":
            clean_answer += "."
    return clean_answer
| # ------------------------------------------------------------------------- | |
| # 5. Load System | |
| # ------------------------------------------------------------------------- | |
def load_system():
    """Verify that all on-disk assets exist, then load every RAG component.

    Loads the vector store, then the LLM, then builds the retrieval chain.

    Raises:
        FileNotFoundError: If the LLM directory, the embedding-model directory
            or the FAISS index directory is missing.
    """
    global vectorstore, llm_pipeline, qa_chain

    # Checked in order; the first missing asset aborts start-up.
    required_assets = (
        (LLM_MODEL_PATH, f"LLM model not found at: {LLM_MODEL_PATH}"),
        (EMBED_MODEL_PATH, f"Embedding model not found at: {EMBED_MODEL_PATH}"),
        (
            FAISS_INDEX_PATH,
            f"FAISS index not found at: {FAISS_INDEX_PATH}\n"
            f"Please run the rebuild_faiss_index.py script first!",
        ),
    )
    for asset_path, missing_message in required_assets:
        if not os.path.exists(asset_path):
            raise FileNotFoundError(missing_message)

    banner = "=" * 70
    print("\n" + banner)
    print("π LOADING RAG SYSTEM: Mistral 7B + VLM2Vec + Reranking (OFFLINE)")
    print(banner + "\n")

    for load_step in (_load_vectorstore, _load_llm, _build_retrieval_chain):
        load_step()
    print("β RAG system ready (100% OFFLINE)!\n")
def _load_embeddings():
    """Create the VLM2Vec embedding model, on the GPU when CUDA is available."""
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    return VLM2VecEmbeddings(model_path=EMBED_MODEL_PATH, device=target_device)
def _read_pickled_documents(docs_path):
    """Deserialize the stored documents, trying several pickle implementations.

    Tried in order: ``pickle5``, the standard ``pickle`` with
    ``encoding='latin1'``, then ``dill``. Failures of individual loaders are
    only reported if every loader fails.

    SECURITY: unpickling runs arbitrary code contained in the file. Only load
    document files produced by a trusted index build.

    Raises:
        RuntimeError: If no loader could read the file.
    """

    def _via_pickle5(handle):
        import pickle5
        return pickle5.load(handle)

    def _via_pickle(handle):
        import pickle
        return pickle.load(handle, encoding='latin1')

    def _via_dill(handle):
        import dill
        return dill.load(handle)

    attempts = (
        ("pickle5", _via_pickle5, " β Loaded with pickle5"),
        ("pickle", _via_pickle, " β Loaded with latin1 encoding"),
        ("dill", _via_dill, " β Loaded with dill"),
    )
    failures = []
    for loader_name, loader, success_message in attempts:
        try:
            with open(docs_path, 'rb') as f:
                documents = loader(f)
        except Exception as e:  # includes ImportError for optional loaders
            failures.append((loader_name, e))
            continue
        if documents is not None:
            print(success_message)
            return documents

    for loader_name, error in failures:
        print(f" β οΈ {loader_name} failed: {error}")
    raise RuntimeError("Could not load documents. Check pickle version.")


def _as_document(raw, document_cls):
    """Return *raw* as a ``document_cls`` instance, never dropping an entry."""
    if isinstance(raw, document_cls):
        return raw
    try:
        return document_cls(
            page_content=raw.page_content if hasattr(raw, 'page_content') else str(raw),
            metadata=raw.metadata if hasattr(raw, 'metadata') else {},
        )
    except Exception as e:
        print(f" β οΈ Could not reconstruct document: {e}")
        # Keep a placeholder so later list positions still line up with the
        # FAISS vector positions.
        return document_cls(page_content=str(raw), metadata={})


def _load_vectorstore():
    """Load the FAISS index and its documents into the global ``vectorstore``.

    Reads ``text_index.faiss`` and ``text_documents.pkl`` from
    ``FAISS_INDEX_PATH``. The pickled payload may be a list of documents
    (position i is paired with FAISS vector i) or a dict of id -> document
    (paired in the dict's key order).

    Raises:
        FileNotFoundError: If either index file is missing.
        RuntimeError: If the documents file cannot be unpickled.
        ValueError: If the unpickled payload is neither a list nor a dict.
    """
    global vectorstore
    import traceback

    import faiss
    from langchain_community.docstore.in_memory import InMemoryDocstore
    from langchain_core.documents import Document

    print(f"π₯ Loading FAISS index from: {FAISS_INDEX_PATH}")
    text_index_path = os.path.join(FAISS_INDEX_PATH, "text_index.faiss")
    text_docs_path = os.path.join(FAISS_INDEX_PATH, "text_documents.pkl")
    if not os.path.exists(text_index_path):
        raise FileNotFoundError(f"text_index.faiss not found at: {text_index_path}")
    if not os.path.exists(text_docs_path):
        raise FileNotFoundError(f"text_documents.pkl not found at: {text_docs_path}")

    embedding_model = _load_embeddings()
    try:
        index = faiss.read_index(text_index_path)
        print(f" π FAISS index loaded: {index.ntotal} vectors")
        print(" π Loading documents...")
        documents = _read_pickled_documents(text_docs_path)

        if isinstance(documents, list):
            print(f" Loaded {len(documents)} documents")
            documents = [_as_document(doc, Document) for doc in documents]
            docstore = InMemoryDocstore({str(i): doc for i, doc in enumerate(documents)})
            index_to_docstore_id = {i: str(i) for i in range(len(documents))}
        elif isinstance(documents, dict):
            print(f" Loaded {len(documents)} documents (dict)")
            docstore = InMemoryDocstore(documents)
            index_to_docstore_id = dict(enumerate(documents))
        else:
            raise ValueError(f"Unexpected documents format: {type(documents)}")

        if index.ntotal != len(index_to_docstore_id):
            # Not fatal, but searches may return ids that have no document.
            logger.warning(
                "FAISS index holds %d vectors but %d documents were loaded",
                index.ntotal,
                len(index_to_docstore_id),
            )

        vectorstore = FAISS(
            embedding_function=embedding_model,
            index=index,
            docstore=docstore,
            index_to_docstore_id=index_to_docstore_id,
        )
        print(f" π Total vectors: {vectorstore.index.ntotal}")
        print("β FAISS vectorstore loaded\n")
    except Exception as e:
        print(f"β Error loading FAISS index: {e}")
        traceback.print_exc()
        raise
def _load_llm():
    """Load the 4-bit quantized LLM and store it in the global ``llm_pipeline``.

    The model is first loaded with Flash Attention 2; if that fails for any
    reason, the failure is logged and the model is loaded again with the
    default attention implementation. The result is wrapped in a
    ``HuggingFacePipeline`` for use in LangChain chains.
    """
    global llm_pipeline

    print(f"π€ Loading LLM from: {LLM_MODEL_PATH} (OFFLINE - SPEED OPTIMIZED)")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_PATH, local_files_only=True)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Arguments shared by both load attempts.
    load_kwargs = {
        "quantization_config": bnb_config,
        "device_map": "auto",
        "local_files_only": True,
    }
    try:
        model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL_PATH,
            attn_implementation="flash_attention_2",
            **load_kwargs,
        )
        print(" β‘ Flash Attention 2 Enabled!")
    except Exception as e:
        # Record why the fast path was unavailable instead of hiding it.
        logger.warning(f"Flash Attention 2 unavailable, using standard attention: {e}")
        print(" β οΈ Flash Attention 2 not supported. Using standard attention.")
        model = AutoModelForCausalLM.from_pretrained(LLM_MODEL_PATH, **load_kwargs)

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.01,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False,  # return only the completion, not the prompt
    )
    llm_pipeline = HuggingFacePipeline(pipeline=pipe)
    print("β LLM Loaded\n")
def format_docs_with_sources(docs):
    """
    Render retrieved documents as one context string with a source header each.

    Every document becomes ``--- REFERENCE: <source> (Page <page>) ---``
    followed by its content. Missing metadata falls back to
    ``"Unknown Document"`` and ``"?"``. Entries are separated by a blank line.
    """

    def _render(document):
        meta = document.metadata
        origin = meta.get("source", "Unknown Document")
        page_label = meta.get("page", "?")
        return f"--- REFERENCE: {origin} (Page {page_label}) ---\n{document.page_content}\n"

    return "\n\n".join(_render(document) for document in docs)
def _build_retrieval_chain():
    """Assemble the retrieval + generation chain and store it in ``qa_chain``.

    Pipeline: FAISS similarity search, combined with BM25 keyword search
    when BM25 can be built, reranked by Flashrank when the reranker can be
    created, wrapped in a history-aware retriever and followed by the
    answer prompt, the LLM and a string parser.

    The BM25 and reranking stages are optional: if either cannot be set up,
    the reason is logged and the chain is built without that stage.

    Raises:
        RuntimeError: If the vector store has not been loaded yet.
    """
    global qa_chain
    if vectorstore is None:
        raise RuntimeError("Vector store is not loaded; call _load_vectorstore() first.")

    print("π Building Production RAG Chain (Sources + Hybrid)...")

    # --- A. RETRIEVER SETUP ---
    # 1. Vector search
    faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})

    # 2. BM25 keyword search over every stored document
    try:
        all_docs = list(vectorstore.docstore._dict.values())
        bm25_retriever = BM25Retriever.from_documents(all_docs)
        bm25_retriever.k = 10
        ensemble_retriever = EnsembleRetriever(
            retrievers=[faiss_retriever, bm25_retriever],
            weights=[0.3, 0.7],
        )
    except Exception as e:
        logger.warning(f"BM25 retriever unavailable, using vector search only: {e}")
        ensemble_retriever = faiss_retriever

    # 3. Reranking (top 5 only)
    try:
        compressor = FlashrankRerank(model="ms-marco-MiniLM-L-12-v2", top_n=5)
        final_retriever = ContextualCompressionRetriever(
            base_compressor=compressor,
            base_retriever=ensemble_retriever,
        )
    except Exception as e:
        logger.warning(f"Flashrank reranker unavailable, skipping reranking: {e}")
        final_retriever = ensemble_retriever

    # --- B. HISTORY AWARENESS ---
    # Reformulate the question based on chat history before retrieving.
    rephrase_prompt = ChatPromptTemplate.from_template(
        "<s>[INST] Rephrase the follow-up question to be a standalone question.\n"
        "Chat History: {chat_history}\n"
        "Follow Up Input: {input}\n"
        "Standalone question: [/INST]"
    )
    history_node = create_history_aware_retriever(
        llm_pipeline,
        final_retriever,
        rephrase_prompt,
    )

    # --- C. FINAL ANSWER GENERATION (with sources) ---
    qa_prompt = ChatPromptTemplate.from_template(
        "[INST] You are a helpful assistant for BPCL-Kochi Refinery.\n"
        "Answer the user's question based strictly on the context provided below.\n"
        "If the answer is not in the context, say "
        "\"I don't have that information in the manuals.\"\n"
        "ALWAYS cite the document name for your answer.\n"
        "CONTEXT WITH SOURCES:\n"
        "{context}\n"
        "USER QUESTION:\n"
        "{input}\n"
        "ANSWER: [/INST]"
    )

    # The chain returns a plain string (StrOutputParser is the last stage).
    qa_chain = (
        {
            "context": history_node | format_docs_with_sources,
            "input": itemgetter("input"),
            "chat_history": itemgetter("chat_history"),
        }
        | qa_prompt
        | llm_pipeline
        | StrOutputParser()
    )
    print("β Production Chain Built (with Citations)\n")
| # ------------------------------------------------------------------------- | |
| # 6. FastAPI App & Endpoints | |
| # ------------------------------------------------------------------------- | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the RAG system on start-up, free caches on exit.

    ``load_system()`` runs before the first request is served. The cleanup
    is in ``finally`` so it also runs when the serving phase ends with an error.
    """
    print("\nπ Starting application (OFFLINE)...")
    load_system()
    logger.info("RAG system initialized (OFFLINE)")
    try:
        yield
    finally:
        print("\nπ Shutting down...")
        answer_cache.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("Shutdown complete")
# Application metadata, shown in the generated API docs.
_app_settings = {
    "title": "BeRU Chat Assistant - VLM2Vec",
    "description": "100% Offline RAG system with VLM2Vec embeddings",
    "version": "2.0-VLM2Vec",
    "lifespan": lifespan,
}
app = FastAPI(**_app_settings)

# CORS is fully open: every origin, method and header, with credentials.
# NOTE(review): wildcard origins together with credentials is very
# permissive - confirm this is intended for the deployment.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
class ChatRequest(BaseModel):
    """Request body of the chat endpoint."""

    # User message, 1-2000 characters (length is checked before stripping).
    message: str = Field(..., min_length=1, max_length=2000)
    # Answer style; must be a key of PROMPT_TEMPLATES.
    mode: str = "Detailed"
    # Conversation identifier used for history and caching.
    session_id: Optional[str] = "default"
    # Accepted but not read by any code visible in this file.
    include_images: bool = False

    @field_validator("message")
    @classmethod
    def sanitize_message(cls, v):
        """Strip surrounding whitespace from the message."""
        return v.strip()

    @field_validator("mode")
    @classmethod
    def validate_mode(cls, v):
        """Fall back to ``"Detailed"`` for an unknown mode instead of rejecting it."""
        if v not in PROMPT_TEMPLATES:
            return "Detailed"
        return v
class QueryRequest(BaseModel):
    """Request body of the query endpoint (single-shot variant of chat)."""

    # User message, 1-2000 characters (length is checked before stripping).
    message: str = Field(..., min_length=1, max_length=2000)
    # Answer style; must be a key of PROMPT_TEMPLATES.
    answer_style: str = "Detailed"
    # Requested number of sources, constrained to 1-10.
    num_sources: int = Field(default=5, ge=1, le=10)

    @field_validator("message")
    @classmethod
    def sanitize_message(cls, v):
        """Strip surrounding whitespace from the message."""
        return v.strip()

    @field_validator("answer_style")
    @classmethod
    def validate_style(cls, v):
        """Fall back to ``"Detailed"`` for an unknown style instead of rejecting it."""
        if v not in PROMPT_TEMPLATES:
            return "Detailed"
        return v
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve ``frontend.html`` from the current working directory.

    If the file is missing or cannot be read, a small HTML error page is
    returned instead. Dynamic text placed in that page is HTML-escaped.
    """
    from html import escape

    try:
        frontend_path = Path("frontend.html")
        if frontend_path.exists():
            return frontend_path.read_text(encoding="utf-8")
        return f"""
<html>
<body>
<h1>Error: frontend.html not found</h1>
<p>Please place frontend.html in the same directory as this script</p>
<p>Current directory: {escape(str(Path.cwd()))}</p>
</body>
</html>
"""
    except Exception as e:
        # The exception text may contain markup characters; never emit it raw.
        return f"<html><body><h1>Error loading frontend</h1><p>{escape(str(e))}</p></body></html>"
# At most three chat requests are processed at the same time.
query_semaphore = asyncio.Semaphore(3)


# NOTE(review): the route path is not visible in this file; "/api/chat" is
# assumed from the "/api/..." URLs printed at start-up - confirm against the
# frontend before deploying.
@app.post("/api/chat")
async def chat_endpoint(request: ChatRequest):
    """Answer one chat message with the RAG chain.

    Flow: return the cached answer for the same (message, mode, session) if
    there is one; otherwise run ``qa_chain`` in a worker thread with the last
    three turns of the session as history, clean the answer, cache it and
    append it to the session history.

    Returns:
        ``JSONResponse`` with ``response``, ``sources``, ``mode``, ``cached``
        and ``images``; status 504 on timeout and 500 on any other error,
        both with ``error`` and ``response`` keys.
    """
    async with query_semaphore:
        try:
            message = request.message
            mode = request.mode
            session_id = request.session_id
            logger.info(f"Chat Query: {message[:100]} | Mode: {mode}")
            print(f"\n{'=' * 60}")
            print(f"π¬ Chat: {message}")
            print(f" Mode: {mode}")
            print(f" Session: {session_id}")

            # History management
            history = conversations.setdefault(session_id, [])

            # Check cache
            cache_key = f"{message}_{mode}_{session_id}"
            if cache_key in answer_cache:
                print("πΎ Cache hit!")
                cached_response = answer_cache[cache_key]
                history.append(
                    {
                        "user": message,
                        "bot": cached_response["response"],
                        "mode": mode,
                    }
                )
                # The stored payload says "cached": False; report the hit
                # without mutating the cached entry.
                return JSONResponse({**cached_response, "cached": True})

            print(f"β±οΈ Generating response (timeout: {GENERATION_TIMEOUT}s)...")

            # Convert the last 3 stored turns into LangChain message objects.
            chat_history_objs = []
            for turn in history[-3:]:
                chat_history_objs.append(HumanMessage(content=turn["user"]))
                chat_history_objs.append(AIMessage(content=turn["bot"]))

            # Run the blocking chain in a worker thread so the event loop
            # stays responsive. On timeout only the wait is abandoned; the
            # worker thread itself is not stopped.
            try:
                result = await asyncio.wait_for(
                    asyncio.to_thread(
                        qa_chain.invoke,
                        {
                            "input": message,
                            "chat_history": chat_history_objs,
                        },
                    ),
                    timeout=GENERATION_TIMEOUT,
                )
            except asyncio.TimeoutError:
                return JSONResponse(
                    {
                        "error": f"Query timeout after {GENERATION_TIMEOUT}s",
                        "response": "Sorry, the request took too long. Please try again.",
                    },
                    status_code=504,
                )

            # The current chain returns a string; a dict result (with
            # "answer" and "context") is still accepted.
            context_docs = []
            if isinstance(result, str):
                # Citations are embedded in the text; no raw documents here.
                answer = result
            elif isinstance(result, dict):
                answer = result.get("answer", "No answer generated")
                context_docs = result.get("context", [])
            else:
                answer = str(result)

            # Clean up the answer text
            answer = structure_answer(answer, mode)

            # Sources are only populated when context documents were returned.
            sources = [
                {
                    "index": i,
                    "file_name": doc.metadata.get("source", "Unknown"),
                    "page": doc.metadata.get("page", "N/A"),
                    "snippet": doc.page_content[:200].replace("\n", " "),
                }
                for i, doc in enumerate(context_docs[:5], 1)
            ]
            print(f"β Response generated: {len(answer)} chars")

            response_data = {
                "response": answer,
                "sources": sources,
                "mode": mode,
                "cached": False,
                "images": [],  # placeholder for image handling
            }
            answer_cache[cache_key] = response_data
            history.append(
                {
                    "user": message,
                    "bot": answer,
                    "mode": mode,
                }
            )
            logger.info("Chat response completed")
            return JSONResponse(response_data)
        except Exception as e:
            logger.error(f"Chat error: {e}", exc_info=True)
            print(f"β ERROR: {e}")
            # The logger only writes to the log file; also show the traceback
            # on the console for debugging.
            import traceback
            traceback.print_exc()
            return JSONResponse(
                {
                    "error": str(e),
                    "response": "Sorry, an internal error occurred. Please check server logs.",
                },
                status_code=500,
            )
# NOTE(review): the route path is not visible in this file; "/api/query" is
# assumed - confirm against the clients of this API.
@app.post("/api/query")
async def query_endpoint(request: QueryRequest):
    """Single-shot query: delegates to ``chat_endpoint`` on the "default" session.

    The chat payload is returned with its ``response`` key renamed to
    ``answer``. The HTTP status of the chat response (200, 500 or 504) is
    passed through. ``request.num_sources`` is validated but not used.
    """
    import json

    chat_request = ChatRequest(
        message=request.message,
        mode=request.answer_style,
        session_id="default",
    )
    response = await chat_endpoint(chat_request)
    json_data = json.loads(response.body.decode("utf-8"))
    if "response" in json_data:
        json_data["answer"] = json_data.pop("response")
    # Keep the original status so errors are not reported as 200 OK.
    return JSONResponse(json_data, status_code=response.status_code)
@app.get("/api/health")
async def health_check():
    """Report liveness plus basic runtime facts (models, CUDA, cache sizes)."""
    return {
        "status": "ok",
        "mode": "OFFLINE",
        "llm_model": LLM_MODEL,
        "embedding_model": EMBED_MODEL,
        "cuda_available": torch.cuda.is_available(),
        "cache_size": len(answer_cache),
        "active_sessions": len(conversations),
    }
@app.get("/api/stats")
async def get_stats():
    """Report document count, cache/session sizes and model/index locations.

    ``documents`` is 0 when the vector store is not loaded and ``"unknown"``
    when the document store cannot be inspected.
    """
    try:
        doc_count = len(vectorstore.docstore._dict) if vectorstore else 0
    except Exception:
        doc_count = "unknown"
    return {
        "mode": "OFFLINE",
        "documents": doc_count,
        "cache_size": len(answer_cache),
        "active_sessions": len(conversations),
        "llm_model": LLM_MODEL,
        "embedding_model": EMBED_MODEL,
        "cuda_available": torch.cuda.is_available(),
        # Explicit string so the payload never depends on Path serialization.
        "index_path": str(FAISS_INDEX_PATH),
    }
# NOTE(review): the route path is not visible in this file;
# "/api/new-conversation" is assumed - confirm against the frontend.
@app.post("/api/new-conversation")
async def new_conversation(request: dict):
    """Reset the history of a session.

    Reads ``session_id`` from the JSON body (default ``"default"``). An
    existing session is emptied; an unknown session is left uncreated.
    """
    session_id = request.get("session_id", "default")
    if session_id in conversations:
        conversations[session_id] = []
    return {"message": "New conversation started", "session_id": session_id}
# NOTE(review): the route path is not visible in this file;
# "/api/conversation/{session_id}" is assumed - confirm against the frontend.
@app.get("/api/conversation/{session_id}")
async def get_conversation(session_id: str):
    """Return the stored turns of *session_id*, or an empty history if unknown."""
    return {"history": conversations.get(session_id, [])}
# NOTE(review): the route path is not visible in this file;
# "/api/clear-cache" is assumed - confirm against the frontend.
@app.post("/api/clear-cache")
async def clear_cache():
    """Empty the answer cache, release cached CUDA memory, report the count removed."""
    cache_size = len(answer_cache)
    answer_cache.clear()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return {"message": f"Cache cleared. Removed {cache_size} entries"}
def main():
    """Parse ``--port``, print the start-up banner and run the uvicorn server."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8001, help="Port to run the server on")
    args = parser.parse_args()
    port = args.port

    print("\n" + "=" * 70)
    print("π BeRU Chat Assistant - VLM2Vec Mode (100% OFFLINE)")
    print("=" * 70)
    print(f"\nπ Frontend: http://localhost:{port}")
    print(f"π API Docs: http://localhost:{port}/docs")
    print(f"π Health: http://localhost:{port}/api/health")
    print(f"π Stats: http://localhost:{port}/api/stats")
    print(f"\nπ Embedding Model (LOCAL): {EMBED_MODEL_PATH}")
    print(f"π LLM Model (LOCAL): {LLM_MODEL_PATH}")
    print(f"π FAISS Index: {FAISS_INDEX_PATH}")
    print("π Mode: 100% OFFLINE (local files only)")
    print("=" * 70 + "\n")
    # Bind to all interfaces so the server is reachable from outside a container.
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")


if __name__ == "__main__":
    main()