from pathlib import Path
import pickle
import time
import importlib.util

import faiss
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from InstructorEmbedding import INSTRUCTOR
from transformers import AutoModelForCausalLM, AutoTokenizer

from config import settings


class _CompatDocument:
    """Fallback placeholder for pickled langchain Document objects."""

    pass


class _CompatUnpickler(pickle.Unpickler):
    """Map langchain document class to a lightweight local placeholder."""

    def find_class(self, module, name):
        """Resolve the langchain Document class to ``_CompatDocument``.

        Every other ``module``/``name`` pair is resolved by the standard
        ``pickle.Unpickler.find_class``.
        """
        if module == "langchain_core.documents.base" and name == "Document":
            return _CompatDocument
        return super().find_class(module, name)


def _load_chunks(path: Path):
    """Load chunks.pkl with normal pickle, then fallback if langchain_core is absent.

    Args:
        path: Location of the pickled chunk list.

    Returns:
        Whatever object the pickle file contains.

    Raises:
        ModuleNotFoundError: If unpickling needs a missing module other than
            ``langchain_core``.
    """
    # NOTE(security): unpickling can execute arbitrary code. This file may have
    # been downloaded from the Hugging Face Hub (see resolve_data_file), so only
    # point the service at repositories you trust.
    with path.open("rb") as f:
        try:
            return pickle.load(f)
        except ModuleNotFoundError as e:
            if e.name != "langchain_core":
                raise

    # langchain_core is not installed: re-read the file from the start and map
    # its Document class onto the local placeholder.
    with path.open("rb") as f:
        return _CompatUnpickler(f).load()


def _chunk_payload(chunk):
    """Return the serialized payload for both real and fallback document objects.

    Objects exposing ``page_content`` and ``metadata`` attributes are read
    directly. Otherwise the instance ``__dict__`` is used, unwrapping a nested
    ``"__dict__"`` entry when one is present. An empty dict is returned when no
    dict payload can be found.
    """
    if hasattr(chunk, "page_content") and hasattr(chunk, "metadata"):
        return {
            "page_content": chunk.page_content,
            "metadata": chunk.metadata,
        }

    raw = getattr(chunk, "__dict__", {})
    # Presumably placeholder objects restored from a pydantic-style state keep
    # the real fields under a nested "__dict__" key -- TODO confirm.
    nested = raw.get("__dict__", raw)
    if isinstance(nested, dict):
        return nested
    return {}


def _chunk_page_content(chunk):
    """Return the chunk's ``page_content``, or ``""`` when it is missing."""
    return _chunk_payload(chunk).get("page_content", "")


def _chunk_metadata(chunk):
    """Return the chunk's ``metadata``, or ``{}`` when it is missing."""
    return _chunk_payload(chunk).get("metadata", {})


def _trim_chunk_text(text: str) -> str:
    """Cut ``text`` to ``settings.max_chars_per_chunk`` characters.

    A limit of zero (or a negative value, which is clamped to zero) disables
    trimming. Trimmed text has trailing whitespace removed and a
    ``...[truncated]`` marker appended on a new line.
    """
    limit = max(0, settings.max_chars_per_chunk)
    if limit == 0 or len(text) <= limit:
        return text
    return text[:limit].rstrip() + "\n...[truncated]"


def find_data_file(filename: str) -> Path:
    """Locate ``filename`` on the local filesystem.

    An existing absolute path is returned as-is; otherwise each entry of
    ``settings.data_search_roots`` is tried in order.

    Raises:
        FileNotFoundError: If no candidate exists.
    """
    explicit = Path(filename)
    if explicit.is_absolute() and explicit.exists():
        return explicit

    for root in settings.data_search_roots:
        candidate = root / filename
        if candidate.exists():
            return candidate

    raise FileNotFoundError(f"Could not find {filename} in expected locations")


def resolve_data_file(filename: str) -> Path:
    """Find ``filename`` locally, optionally falling back to a Hub download.

    The download is attempted only when ``settings.allow_hf_assets_download``
    is truthy and ``settings.hf_assets_repo_id`` is set. The path under
    ``settings.hf_assets_subdir`` is tried first, then the bare filename.

    Raises:
        FileNotFoundError: If the file is missing locally and downloading is
            disabled, unconfigured, or fails for every candidate path.
    """
    try:
        return find_data_file(filename)
    except FileNotFoundError:
        if not settings.allow_hf_assets_download:
            raise

    if not settings.hf_assets_repo_id:
        raise FileNotFoundError(
            f"Could not find {filename} locally and HF_ASSETS_REPO_ID is not configured"
        )

    subdir = settings.hf_assets_subdir.strip("/")
    preferred_filename = f"{subdir}/{filename}" if subdir else filename
    fallback_filename = filename

    attempts = [preferred_filename]
    if fallback_filename != preferred_filename:
        attempts.append(fallback_filename)

    last_error = None
    for candidate in attempts:
        try:
            downloaded = hf_hub_download(
                repo_id=settings.hf_assets_repo_id,
                filename=candidate,
                repo_type="model",
            )
            print(f"Downloaded {candidate} from {settings.hf_assets_repo_id}")
            return Path(downloaded)
        except Exception as exc:
            # Deliberately broad: any failure moves on to the next candidate;
            # the last error is chained onto the final FileNotFoundError.
            last_error = exc

    raise FileNotFoundError(
        f"Could not find {filename} locally or in Hugging Face repo {settings.hf_assets_repo_id}"
    ) from last_error


class AppState:
    """Container for the resources loaded by ``preload``."""

    def __init__(self):
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model_id = settings.model_id
        # Initial guess; preload() refines this (e.g. bfloat16 when supported).
        self.model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.index = None
        self.chunks = None
        self.embedding_model = None
        self.model = None
        self.tokenizer = None
        self.startup_timing = {}


state = AppState()


def _has_accelerate() -> bool:
    """Return True when the ``accelerate`` package can be imported."""
    return importlib.util.find_spec("accelerate") is not None


def _load_causal_lm(model_id: str, dtype: torch.dtype, device: str):
    """Load a causal LM with an accelerate-aware device placement fallback.

    When ``accelerate`` is available the whole model is mapped to ``device``
    through ``device_map``; otherwise placement is left to the caller.
    """
    base_kwargs = {"dtype": dtype}
    if _has_accelerate():
        base_kwargs["device_map"] = {"": device}

    try:
        return AutoModelForCausalLM.from_pretrained(model_id, **base_kwargs)
    except TypeError:
        # Backward compatibility for transformers versions that still expect torch_dtype.
        fallback_kwargs = dict(base_kwargs)
        fallback_kwargs.pop("dtype", None)
        fallback_kwargs["torch_dtype"] = dtype
        return AutoModelForCausalLM.from_pretrained(model_id, **fallback_kwargs)


def retrieve_chunks(query: str, k: int) -> list:
    """Return up to ``k`` chunks nearest to ``query`` in the vector index.

    Args:
        query: Natural-language query; embedded together with
            ``settings.retrieval_instruction``.
        k: Number of neighbours requested from the index.

    Returns:
        The matching entries of ``state.chunks`` in the order the index
        returned them. Fewer than ``k`` items are returned when the index
        cannot supply ``k`` neighbours.
    """
    query_embedding = state.embedding_model.encode(
        [[settings.retrieval_instruction, query]]
    )[0]
    query_vector = np.array([query_embedding]).astype("float32")
    _distances, indices = state.index.search(query_vector, k)
    # FAISS pads missing neighbours with label -1; indexing chunks with -1
    # would silently return the last chunk as a bogus hit, so skip those.
    return [state.chunks[i] for i in indices[0] if i >= 0]


def generate_answer(question: str, retrieved_chunks: list) -> str:
    """Generate an answer to ``question`` grounded in ``retrieved_chunks``.

    Each chunk's (possibly trimmed) text is presented as a numbered source in
    the user message, the chat template is applied, and only the newly
    generated tokens are decoded and returned (stripped).
    """
    context = "".join(
        f"Source {i}:\n{_trim_chunk_text(_chunk_page_content(chunk))}\n\n"
        for i, chunk in enumerate(retrieved_chunks, start=1)
    )

    messages = [
        {
            "role": "system",
            "content": (
                "You are a helpful assistant that answers questions using ONLY the provided sources. "
                "Synthesize information from ALL sources given. "
                "Give a complete and coherent answer. "
                "Do not cut off mid sentence. "
                "If the sources do not contain enough information say so clearly."
            ),
        },
        {
            "role": "user",
            "content": (
                f"Question: {question}\n\n"
                f"{context}"
                "Based on ALL the sources above provide a complete answer to the question."
            ),
        },
    ]

    text = state.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = state.tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=settings.max_context_tokens,
    ).to(state.device)

    generation_kwargs = {
        "max_new_tokens": settings.max_new_tokens,
        "do_sample": settings.do_sample,
        "pad_token_id": state.tokenizer.eos_token_id,
        "repetition_penalty": settings.repetition_penalty,
    }
    # Temperature is only passed when sampling is enabled.
    if settings.do_sample:
        generation_kwargs["temperature"] = settings.temperature

    with torch.inference_mode():
        output = state.model.generate(
            **inputs,
            **generation_kwargs,
        )

    # Drop the prompt tokens so only the completion is decoded.
    generated_tokens = output[0][inputs["input_ids"].shape[1] :]
    return state.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()


def rag_query(question: str, k: int) -> dict:
    """Run retrieval and generation for ``question`` and time both stages.

    Returns:
        A dict with ``question``, ``answer``, ``sources`` (the ``url`` metadata
        of each retrieved chunk, ``""`` when absent) and ``timing`` in seconds.
    """
    t0 = time.perf_counter()

    t_retrieve_start = time.perf_counter()
    retrieved = retrieve_chunks(question, k=k)
    retrieval_time = time.perf_counter() - t_retrieve_start

    t_generate_start = time.perf_counter()
    answer = generate_answer(question, retrieved)
    generation_time = time.perf_counter() - t_generate_start

    total_time = time.perf_counter() - t0
    sources = [_chunk_metadata(chunk).get("url", "") for chunk in retrieved]

    return {
        "question": question,
        "answer": answer,
        "sources": sources,
        "timing": {
            "retrieval_seconds": retrieval_time,
            "generation_seconds": generation_time,
            "total_seconds": total_time,
        },
    }


def preload() -> dict:
    """Load the index, chunks, embedding model and LLM into ``state``.

    Returns:
        The startup timing dict (also stored on ``state.startup_timing``),
        with per-stage and total durations in seconds.
    """
    t0 = time.perf_counter()

    print(f"Using device : {state.device}")
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        print(f"CUDA available : True ({gpu_name})")
        if torch.cuda.is_bf16_supported():
            state.model_dtype = torch.bfloat16
        else:
            state.model_dtype = torch.float16
        print(f"Model dtype : {state.model_dtype}")
    else:
        print("CUDA available : False")
        state.model_dtype = torch.float32

    print("Loading vector DB...")
    t_index = time.perf_counter()
    index_path = resolve_data_file(settings.vector_db_file)
    state.index = faiss.read_index(str(index_path))
    index_time = time.perf_counter() - t_index
    print(f"Index loaded : {state.index.ntotal} vectors")

    print("Loading chunks...")
    t_chunks = time.perf_counter()
    chunks_path = resolve_data_file(settings.chunks_file)
    state.chunks = _load_chunks(chunks_path)
    chunks_time = time.perf_counter() - t_chunks
    print(f"Chunks loaded : {len(state.chunks)}")

    print("Loading embedding model...")
    t_embed = time.perf_counter()
    state.embedding_model = INSTRUCTOR(settings.embedding_model_id)
    if torch.cuda.is_available():
        try:
            state.embedding_model.to(state.device)
        except Exception as exc:
            # Some InstructorEmbedding backends do not expose .to(); keep CPU fallback.
            # Still best-effort, but report it so a slow CPU embedder is not a mystery.
            print(
                f"Embedding model could not be moved to {state.device}; "
                f"keeping its current device ({exc!r})"
            )
    embedding_time = time.perf_counter() - t_embed

    print(f"Loading {settings.model_id}...")
    t_model = time.perf_counter()
    state.model = _load_causal_lm(
        model_id=settings.model_id,
        dtype=state.model_dtype,
        device=state.device,
    )
    state.tokenizer = AutoTokenizer.from_pretrained(settings.model_id)
    # Without accelerate no device_map was applied, so place the model manually.
    if not _has_accelerate():
        state.model.to(state.device)
    state.model.eval()
    model_time = time.perf_counter() - t_model

    first_param_device = str(next(state.model.parameters()).device)
    print(f"LLM loaded on : {first_param_device}")

    total_startup = time.perf_counter() - t0
    state.startup_timing = {
        "index_load_seconds": index_time,
        "chunks_load_seconds": chunks_time,
        "embedding_model_load_seconds": embedding_time,
        "llm_load_seconds": model_time,
        "total_startup_seconds": total_startup,
    }

    print("RAG API preloaded successfully")
    print(
        f"Startup timing: total={total_startup:.2f}s, index={index_time:.2f}s, "
        f"chunks={chunks_time:.2f}s, embedding={embedding_time:.2f}s, model={model_time:.2f}s"
    )
    return state.startup_timing