| from pathlib import Path |
| import pickle |
| import time |
| import importlib.util |
|
|
| import faiss |
| import numpy as np |
| import torch |
| from huggingface_hub import hf_hub_download |
| from InstructorEmbedding import INSTRUCTOR |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| from config import settings |
|
|
|
|
| class _CompatDocument: |
| """Fallback placeholder for pickled langchain Document objects.""" |
|
|
| pass |
|
|
|
|
| class _CompatUnpickler(pickle.Unpickler): |
| """Map langchain document class to a lightweight local placeholder.""" |
|
|
| def find_class(self, module, name): |
| if module == "langchain_core.documents.base" and name == "Document": |
| return _CompatDocument |
| return super().find_class(module, name) |
|
|
|
|
| def _load_chunks(path: Path): |
| """Load chunks.pkl with normal pickle, then fallback if langchain_core is absent.""" |
| with path.open("rb") as f: |
| try: |
| return pickle.load(f) |
| except ModuleNotFoundError as e: |
| if e.name != "langchain_core": |
| raise |
|
|
| with path.open("rb") as f: |
| return _CompatUnpickler(f).load() |
|
|
|
|
| def _chunk_payload(chunk): |
| """Return the serialized payload for both real and fallback document objects.""" |
| if hasattr(chunk, "page_content") and hasattr(chunk, "metadata"): |
| return { |
| "page_content": chunk.page_content, |
| "metadata": chunk.metadata, |
| } |
|
|
| raw = getattr(chunk, "__dict__", {}) |
| nested = raw.get("__dict__", raw) |
| if isinstance(nested, dict): |
| return nested |
| return {} |
|
|
|
|
def _chunk_page_content(chunk):
    """Return the chunk's ``page_content``, or ``""`` when the payload has none."""
    payload = _chunk_payload(chunk)
    return payload.get("page_content", "")
|
|
|
|
def _chunk_metadata(chunk):
    """Return the chunk's ``metadata``, or ``{}`` when the payload has none."""
    payload = _chunk_payload(chunk)
    return payload.get("metadata", {})
|
|
|
|
def _trim_chunk_text(text: str) -> str:
    """Cap ``text`` at ``settings.max_chars_per_chunk`` characters.

    A zero or negative setting disables trimming. Text longer than the cap is
    cut at the cap, right-stripped, and suffixed with a truncation marker on
    its own line.
    """
    max_chars = max(0, settings.max_chars_per_chunk)
    needs_trim = max_chars != 0 and len(text) > max_chars
    if not needs_trim:
        return text

    clipped = text[:max_chars].rstrip()
    return f"{clipped}\n...[truncated]"
|
|
|
|
def find_data_file(filename: str) -> Path:
    """Locate ``filename`` on the local filesystem.

    An absolute path that exists is returned as-is. Otherwise each entry of
    ``settings.data_search_roots`` is tried in order and the first existing
    ``root / filename`` wins.

    Raises:
        FileNotFoundError: If no candidate exists.
    """
    as_given = Path(filename)
    if as_given.is_absolute() and as_given.exists():
        return as_given

    # Lazily walk the configured roots so probing stops at the first match.
    candidates = (root / filename for root in settings.data_search_roots)
    hit = next((path for path in candidates if path.exists()), None)
    if hit is None:
        raise FileNotFoundError(f"Could not find {filename} in expected locations")
    return hit
|
|
|
|
def resolve_data_file(filename: str) -> Path:
    """Find ``filename`` locally, else fetch it from the configured HF repo.

    The local lookup goes through ``find_data_file``. When that fails and
    ``settings.allow_hf_assets_download`` is truthy, the file is requested from
    ``settings.hf_assets_repo_id``: first under ``settings.hf_assets_subdir``
    (if one is set), then at the repository root.

    Raises:
        FileNotFoundError: If the file is missing locally and downloads are
            disabled, no repo id is configured, or every download attempt
            fails (the last download error is chained as the cause).
    """
    try:
        return find_data_file(filename)
    except FileNotFoundError:
        if not settings.allow_hf_assets_download:
            raise

    if not settings.hf_assets_repo_id:
        raise FileNotFoundError(
            f"Could not find {filename} locally and HF_ASSETS_REPO_ID is not configured"
        )

    # Remote names to try, most specific first; the bare filename is always
    # included exactly once.
    prefix = settings.hf_assets_subdir.strip("/")
    remote_names = [f"{prefix}/{filename}"] if prefix else []
    if filename not in remote_names:
        remote_names.append(filename)

    failure = None
    for candidate in remote_names:
        try:
            local_copy = hf_hub_download(
                repo_id=settings.hf_assets_repo_id,
                filename=candidate,
                repo_type="model",
            )
            print(f"Downloaded {candidate} from {settings.hf_assets_repo_id}")
            return Path(local_copy)
        except Exception as exc:
            # Remember the error and move on to the next remote name.
            failure = exc

    raise FileNotFoundError(
        f"Could not find {filename} locally or in Hugging Face repo {settings.hf_assets_repo_id}"
    ) from failure
|
|
|
|
class AppState:
    """Container for the runtime objects shared across this module.

    It records the chosen device, model id and dtype, and provides slots for
    the vector index, chunks, embedding model, LLM, tokenizer and startup
    timings, all of which start out empty.
    """

    def __init__(self):
        cuda_ready = torch.cuda.is_available()

        # Placement and precision defaults derived from CUDA availability.
        self.device = "cuda:0" if cuda_ready else "cpu"
        self.model_id = settings.model_id
        self.model_dtype = torch.float16 if cuda_ready else torch.float32

        # Heavy resources are left unset here; preload() assigns them.
        self.index = None
        self.chunks = None
        self.embedding_model = None
        self.model = None
        self.tokenizer = None
        self.startup_timing = {}


# Shared module-level instance, created at import time.
state = AppState()
|
|
|
|
| def _has_accelerate() -> bool: |
| return importlib.util.find_spec("accelerate") is not None |
|
|
|
|
def _load_causal_lm(model_id: str, dtype: torch.dtype, device: str):
    """Load a causal LM, placing it via ``device_map`` when accelerate exists.

    Args:
        model_id: Identifier handed to ``AutoModelForCausalLM.from_pretrained``.
        dtype: Precision to load the weights in.
        device: Target device; only used when ``accelerate`` is importable, in
            which case the whole model is mapped onto it.

    Returns:
        The model returned by ``from_pretrained``.
    """
    placement = {"device_map": {"": device}} if _has_accelerate() else {}

    try:
        return AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype, **placement)
    except TypeError:
        # Retry with the ``torch_dtype`` keyword in place of ``dtype``
        # (presumably for transformers versions that reject ``dtype``).
        return AutoModelForCausalLM.from_pretrained(
            model_id, **placement, torch_dtype=dtype
        )
|
|
|
|
def retrieve_chunks(query: str, k: int) -> list:
    """Return up to ``k`` stored chunks nearest to ``query`` in the vector index.

    The query is embedded together with ``settings.retrieval_instruction`` by
    ``state.embedding_model``, searched against ``state.index``, and the
    resulting ids are used to look up entries of ``state.chunks``.

    Args:
        query: Text to search for.
        k: Number of neighbours to request from the index.

    Returns:
        The matching chunks in the order the index returned them. Fewer than
        ``k`` items are returned when the index has fewer than ``k`` results.
    """
    instruction_pair = [settings.retrieval_instruction, query]
    query_embedding = state.embedding_model.encode([instruction_pair])[0]
    query_vector = np.array([query_embedding]).astype("float32")
    _distances, indices = state.index.search(query_vector, k)
    # FAISS pads the id row with -1 when fewer than k neighbours exist. Used
    # as a list index, -1 would silently return the *last* chunk as a bogus
    # hit, so those placeholders are dropped.
    return [state.chunks[i] for i in indices[0] if i >= 0]
|
|
|
|
def generate_answer(question: str, retrieved_chunks: list) -> str:
    """Generate an answer to ``question`` grounded in ``retrieved_chunks``.

    Each chunk's text is trimmed and listed as a numbered source, wrapped in a
    system/user chat prompt, rendered through the tokenizer's chat template,
    and passed to ``state.model.generate``.

    Args:
        question: The user's question.
        retrieved_chunks: Chunks whose page content forms the prompt context.

    Returns:
        The decoded completion only (prompt tokens removed, special tokens
        skipped), stripped of surrounding whitespace.
    """
    # One "Source N:" section per chunk, numbered from 1.
    context = "".join(
        f"Source {number}:\n{_trim_chunk_text(_chunk_page_content(chunk))}\n\n"
        for number, chunk in enumerate(retrieved_chunks, start=1)
    )

    messages = [
        {
            "role": "system",
            "content": (
                "You are a helpful assistant that answers questions using ONLY the provided sources. "
                "Synthesize information from ALL sources given. "
                "Give a complete and coherent answer. "
                "Do not cut off mid sentence. "
                "If the sources do not contain enough information say so clearly."
            ),
        },
        {
            "role": "user",
            "content": (
                f"Question: {question}\n\n"
                f"{context}"
                "Based on ALL the sources above provide a complete answer to the question."
            ),
        },
    ]

    text = state.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # The rendered template already contains the special tokens it needs, so
    # the tokenizer must not add its own on top (that can duplicate BOS).
    # NOTE(review): with the tokenizer's usual right-side truncation, an
    # over-long prompt would lose its tail (final instruction and generation
    # prompt) — confirm max_context_tokens leaves enough headroom.
    inputs = state.tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=settings.max_context_tokens,
        add_special_tokens=False,
    ).to(state.device)

    generation_kwargs = {
        "max_new_tokens": settings.max_new_tokens,
        "do_sample": settings.do_sample,
        "pad_token_id": state.tokenizer.eos_token_id,
        "repetition_penalty": settings.repetition_penalty,
    }
    # Temperature is only passed along when sampling is enabled.
    if settings.do_sample:
        generation_kwargs["temperature"] = settings.temperature

    with torch.inference_mode():
        output = state.model.generate(
            **inputs,
            **generation_kwargs,
        )

    # generate() returns prompt + completion; slice off the prompt tokens.
    prompt_length = inputs["input_ids"].shape[1]
    generated_tokens = output[0][prompt_length:]
    return state.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
|
|
|
|
def rag_query(question: str, k: int) -> dict:
    """Run retrieval then generation for ``question`` and report timings.

    Args:
        question: The user's question.
        k: Number of chunks to retrieve.

    Returns:
        A dict with the ``question``, the generated ``answer``, the ``sources``
        (each retrieved chunk's ``url`` metadata, ``""`` when absent) and a
        ``timing`` dict of retrieval, generation and total seconds.
    """
    started = time.perf_counter()

    retrieve_started = time.perf_counter()
    hits = retrieve_chunks(question, k=k)
    retrieval_seconds = time.perf_counter() - retrieve_started

    generate_started = time.perf_counter()
    reply = generate_answer(question, hits)
    generation_seconds = time.perf_counter() - generate_started
    # Total is taken before the source list is assembled.
    total_seconds = time.perf_counter() - started

    timing = {
        "retrieval_seconds": retrieval_seconds,
        "generation_seconds": generation_seconds,
        "total_seconds": total_seconds,
    }

    source_urls = []
    for chunk in hits:
        source_urls.append(_chunk_metadata(chunk).get("url", ""))

    return {
        "question": question,
        "answer": reply,
        "sources": source_urls,
        "timing": timing,
    }
|
|
|
|
def preload() -> dict:
    """Load every heavy resource into the shared ``state`` and time each step.

    In order, this picks the model dtype from CUDA capabilities, then loads
    the FAISS index, the pickled chunks, the INSTRUCTOR embedding model, and
    the causal LM with its tokenizer. Progress is reported with ``print``.

    Returns:
        The timing dict also stored on ``state.startup_timing``, with keys
        ``index_load_seconds``, ``chunks_load_seconds``,
        ``embedding_model_load_seconds``, ``llm_load_seconds`` and
        ``total_startup_seconds``.

    Raises:
        FileNotFoundError: Propagated from ``resolve_data_file`` when the
            index or chunks file cannot be found or downloaded.
    """
    t0 = time.perf_counter()
    cuda_available = torch.cuda.is_available()

    print(f"Using device : {state.device}")
    if cuda_available:
        gpu_name = torch.cuda.get_device_name(0)
        print(f"CUDA available : True ({gpu_name})")
        # Use bfloat16 where the GPU supports it, float16 otherwise.
        state.model_dtype = (
            torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        )
        print(f"Model dtype : {state.model_dtype}")
    else:
        print("CUDA available : False")
        state.model_dtype = torch.float32

    print("Loading vector DB...")
    t_index = time.perf_counter()
    index_path = resolve_data_file(settings.vector_db_file)
    state.index = faiss.read_index(str(index_path))
    index_time = time.perf_counter() - t_index
    print(f"Index loaded : {state.index.ntotal} vectors")

    print("Loading chunks...")
    t_chunks = time.perf_counter()
    chunks_path = resolve_data_file(settings.chunks_file)
    state.chunks = _load_chunks(chunks_path)
    chunks_time = time.perf_counter() - t_chunks
    print(f"Chunks loaded : {len(state.chunks)}")

    print("Loading embedding model...")
    t_embed = time.perf_counter()
    state.embedding_model = INSTRUCTOR(settings.embedding_model_id)
    if cuda_available:
        try:
            state.embedding_model.to(state.device)
        except Exception as exc:
            # Best effort: keep whatever device the model loaded on, but
            # report the failure instead of hiding it.
            print(f"Warning: could not move embedding model to {state.device}: {exc}")
    embedding_time = time.perf_counter() - t_embed

    print(f"Loading {settings.model_id}...")
    t_model = time.perf_counter()
    state.model = _load_causal_lm(
        model_id=settings.model_id,
        dtype=state.model_dtype,
        device=state.device,
    )
    state.tokenizer = AutoTokenizer.from_pretrained(settings.model_id)
    # Without accelerate no device_map was passed to the loader, so the model
    # is moved to the target device explicitly.
    if not _has_accelerate():
        state.model.to(state.device)
    state.model.eval()
    model_time = time.perf_counter() - t_model

    first_param_device = str(next(state.model.parameters()).device)
    print(f"LLM loaded on : {first_param_device}")

    total_startup = time.perf_counter() - t0
    state.startup_timing = {
        "index_load_seconds": index_time,
        "chunks_load_seconds": chunks_time,
        "embedding_model_load_seconds": embedding_time,
        "llm_load_seconds": model_time,
        "total_startup_seconds": total_startup,
    }

    print("RAG API preloaded successfully")
    print(
        f"Startup timing: total={total_startup:.2f}s, index={index_time:.2f}s, "
        f"chunks={chunks_time:.2f}s, embedding={embedding_time:.2f}s, model={model_time:.2f}s"
    )
    return state.startup_timing
|
|