rag / rag_service.py
pratyoos's picture
Upload 6 files
79a28b6 verified
from pathlib import Path
import pickle
import time
import importlib.util
import faiss
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from InstructorEmbedding import INSTRUCTOR
from transformers import AutoModelForCausalLM, AutoTokenizer
from config import settings
class _CompatDocument:
"""Fallback placeholder for pickled langchain Document objects."""
pass
class _CompatUnpickler(pickle.Unpickler):
"""Map langchain document class to a lightweight local placeholder."""
def find_class(self, module, name):
if module == "langchain_core.documents.base" and name == "Document":
return _CompatDocument
return super().find_class(module, name)
def _load_chunks(path: Path):
"""Load chunks.pkl with normal pickle, then fallback if langchain_core is absent."""
with path.open("rb") as f:
try:
return pickle.load(f)
except ModuleNotFoundError as e:
if e.name != "langchain_core":
raise
with path.open("rb") as f:
return _CompatUnpickler(f).load()
def _chunk_payload(chunk):
"""Return the serialized payload for both real and fallback document objects."""
if hasattr(chunk, "page_content") and hasattr(chunk, "metadata"):
return {
"page_content": chunk.page_content,
"metadata": chunk.metadata,
}
raw = getattr(chunk, "__dict__", {})
nested = raw.get("__dict__", raw)
if isinstance(nested, dict):
return nested
return {}
def _chunk_page_content(chunk):
    """Return the chunk's ``page_content``, or ``""`` when it has none."""
    payload = _chunk_payload(chunk)
    return payload.get("page_content", "")
def _chunk_metadata(chunk):
    """Return the chunk's ``metadata``, or ``{}`` when it has none."""
    payload = _chunk_payload(chunk)
    return payload.get("metadata", {})
def _trim_chunk_text(text: str) -> str:
    """Cap *text* at ``settings.max_chars_per_chunk`` characters.

    A non-positive setting disables trimming. Trimmed text has trailing
    whitespace removed and a ``...[truncated]`` marker appended on a new line.
    """
    max_chars = settings.max_chars_per_chunk
    if max_chars <= 0 or len(text) <= max_chars:
        return text
    return f"{text[:max_chars].rstrip()}\n...[truncated]"
def find_data_file(filename: str) -> Path:
    """Locate *filename* on the local filesystem.

    An existing absolute path is returned as-is. Otherwise each entry of
    ``settings.data_search_roots`` is probed in order and the first existing
    ``root / filename`` wins.

    Raises:
        FileNotFoundError: if no candidate exists.
    """
    direct = Path(filename)
    if direct.is_absolute() and direct.exists():
        return direct

    # Lazy generator: roots are probed one at a time, stopping at the first hit.
    candidates = (root / filename for root in settings.data_search_roots)
    match = next((path for path in candidates if path.exists()), None)
    if match is None:
        raise FileNotFoundError(f"Could not find {filename} in expected locations")
    return match
def resolve_data_file(filename: str) -> Path:
    """Locate *filename* locally, falling back to a Hugging Face Hub download.

    The local lookup is delegated to ``find_data_file``. When that fails and
    ``settings.allow_hf_assets_download`` is truthy, the file is requested
    from ``settings.hf_assets_repo_id``: first under
    ``settings.hf_assets_subdir`` (when one is set), then at the repository
    root.

    Raises:
        FileNotFoundError: if the file is missing locally and downloads are
            disabled, no repository is configured, or every download attempt
            fails (chained to the last download error).
    """
    try:
        return find_data_file(filename)
    except FileNotFoundError:
        if not settings.allow_hf_assets_download:
            raise

    repo_id = settings.hf_assets_repo_id
    if not repo_id:
        raise FileNotFoundError(
            f"Could not find {filename} locally and HF_ASSETS_REPO_ID is not configured"
        )

    # Try the configured sub-directory first, then the bare filename.
    prefix = settings.hf_assets_subdir.strip("/")
    remote_names = [f"{prefix}/{filename}", filename] if prefix else [filename]

    failure = None
    for remote_name in remote_names:
        try:
            local_copy = hf_hub_download(
                repo_id=repo_id,
                filename=remote_name,
                repo_type="model",
            )
            print(f"Downloaded {remote_name} from {repo_id}")
            return Path(local_copy)
        except Exception as exc:
            # Remember the error and move on to the next candidate name.
            failure = exc

    raise FileNotFoundError(
        f"Could not find {filename} locally or in Hugging Face repo {repo_id}"
    ) from failure
class AppState:
    """Mutable holder for the resources the RAG service shares across calls."""

    def __init__(self):
        cuda_ready = torch.cuda.is_available()

        # Device and default precision follow CUDA availability.
        self.device = "cuda:0" if cuda_ready else "cpu"
        self.model_dtype = torch.float16 if cuda_ready else torch.float32
        self.model_id = settings.model_id

        # Resource slots start empty.
        self.index = None
        self.chunks = None
        self.embedding_model = None
        self.model = None
        self.tokenizer = None

        # Startup timings, empty at construction.
        self.startup_timing = {}
# Module-level shared state: filled in by preload() and read by the
# retrieval and generation functions below.
state = AppState()
def _has_accelerate() -> bool:
return importlib.util.find_spec("accelerate") is not None
def _load_causal_lm(model_id: str, dtype: torch.dtype, device: str):
    """Load a causal LM, placing it via ``device_map`` when accelerate is present.

    The precision is first passed as ``dtype``. If that call raises
    ``TypeError`` the load is retried with the ``torch_dtype`` keyword instead.
    Without accelerate no ``device_map`` is passed.
    """
    # An empty-string key maps the whole model onto a single device.
    placement = {"device_map": {"": device}} if _has_accelerate() else {}

    try:
        return AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype, **placement)
    except TypeError:
        # Backward compatibility for transformers versions that still expect torch_dtype.
        return AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=dtype, **placement
        )
def retrieve_chunks(query: str, k: int) -> list:
    """Return up to *k* chunks most similar to *query*.

    The query is embedded together with ``settings.retrieval_instruction`` and
    searched against ``state.index``; the resulting labels are used as
    positions into ``state.chunks``.

    Args:
        query: Question text to embed.
        k: Maximum number of neighbours to request. Non-positive values yield
            an empty list without running a search.

    Returns:
        The matching chunks in the order the index returned them. The list is
        shorter than *k* when the index cannot supply that many neighbours.
    """
    if k <= 0:
        return []

    query_embedding = state.embedding_model.encode(
        [[settings.retrieval_instruction, query]]
    )[0]
    query_vector = np.array([query_embedding]).astype("float32")
    _distances, indices = state.index.search(query_vector, k)

    # FAISS reports missing neighbours with the label -1 (for example when k
    # exceeds the number of indexed vectors). Indexing state.chunks with -1
    # would silently return the last chunk, so those labels are dropped.
    return [state.chunks[int(label)] for label in indices[0] if label >= 0]
def generate_answer(question: str, retrieved_chunks: list) -> str:
    """Generate an answer to *question* using the retrieved chunks as sources.

    Each chunk's text is trimmed and listed as ``Source N``. The prompt is
    rendered with the tokenizer's chat template, tokenized (truncated to
    ``settings.max_context_tokens``) and passed to ``state.model.generate``.

    Args:
        question: The user's question.
        retrieved_chunks: Chunks to present as numbered sources.

    Returns:
        The decoded newly generated text, stripped of surrounding whitespace
        and with special tokens removed.
    """
    # Assemble the numbered source listing in one pass instead of repeated +=.
    context = "".join(
        f"Source {position}:\n{_trim_chunk_text(_chunk_page_content(chunk))}\n\n"
        for position, chunk in enumerate(retrieved_chunks, start=1)
    )

    messages = [
        {
            "role": "system",
            "content": (
                "You are a helpful assistant that answers questions using ONLY the provided sources. "
                "Synthesize information from ALL sources given. "
                "Give a complete and coherent answer. "
                "Do not cut off mid sentence. "
                "If the sources do not contain enough information say so clearly."
            ),
        },
        {
            "role": "user",
            "content": (
                f"Question: {question}\n\n"
                f"{context}"
                "Based on ALL the sources above provide a complete answer to the question."
            ),
        },
    ]

    text = state.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # The rendered chat template already carries the special tokens it needs;
    # letting the tokenizer add its own again can duplicate them (e.g. BOS).
    inputs = state.tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=settings.max_context_tokens,
        add_special_tokens=False,
    ).to(state.device)

    generation_kwargs = {
        "max_new_tokens": settings.max_new_tokens,
        "do_sample": settings.do_sample,
        "pad_token_id": state.tokenizer.eos_token_id,
        "repetition_penalty": settings.repetition_penalty,
    }
    # Temperature is only passed when sampling is enabled.
    if settings.do_sample:
        generation_kwargs["temperature"] = settings.temperature

    with torch.inference_mode():
        output = state.model.generate(
            **inputs,
            **generation_kwargs,
        )

    # Skip the prompt tokens so only the newly generated ones are decoded.
    prompt_length = inputs["input_ids"].shape[1]
    generated_tokens = output[0][prompt_length:]
    return state.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
def rag_query(question: str, k: int) -> dict:
    """Run retrieval then generation for *question* and report timings.

    Returns a dict with the question, the generated answer, the ``url``
    metadata entry of each retrieved chunk (``""`` when absent) and the
    retrieval / generation / total wall-clock durations in seconds.
    """
    overall_start = time.perf_counter()

    # Retrieval phase.
    phase_start = time.perf_counter()
    retrieved = retrieve_chunks(question, k=k)
    retrieval_time = time.perf_counter() - phase_start

    # Generation phase.
    phase_start = time.perf_counter()
    answer = generate_answer(question, retrieved)
    generation_time = time.perf_counter() - phase_start

    timing = {
        "retrieval_seconds": retrieval_time,
        "generation_seconds": generation_time,
        "total_seconds": time.perf_counter() - overall_start,
    }

    source_urls = []
    for chunk in retrieved:
        source_urls.append(_chunk_metadata(chunk).get("url", ""))

    return {
        "question": question,
        "answer": answer,
        "sources": source_urls,
        "timing": timing,
    }
def preload() -> dict:
    """Load every resource the service needs and record how long each took.

    In order: picks the model dtype from CUDA / bf16 availability, reads the
    FAISS index, unpickles the chunks, loads the embedding model and then the
    causal LM with its tokenizer. Everything is stored on the module-level
    ``state`` object.

    Returns:
        The ``state.startup_timing`` dict with per-stage and total durations
        in seconds.
    """
    t0 = time.perf_counter()

    print(f"Using device : {state.device}")
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        print(f"CUDA available : True ({gpu_name})")
        # Prefer bfloat16 where the GPU supports it, otherwise float16.
        if torch.cuda.is_bf16_supported():
            state.model_dtype = torch.bfloat16
        else:
            state.model_dtype = torch.float16
        print(f"Model dtype : {state.model_dtype}")
    else:
        print("CUDA available : False")
        state.model_dtype = torch.float32

    print("Loading vector DB...")
    t_index = time.perf_counter()
    index_path = resolve_data_file(settings.vector_db_file)
    state.index = faiss.read_index(str(index_path))
    index_time = time.perf_counter() - t_index
    print(f"Index loaded : {state.index.ntotal} vectors")

    print("Loading chunks...")
    t_chunks = time.perf_counter()
    chunks_path = resolve_data_file(settings.chunks_file)
    state.chunks = _load_chunks(chunks_path)
    chunks_time = time.perf_counter() - t_chunks
    print(f"Chunks loaded : {len(state.chunks)}")

    print("Loading embedding model...")
    t_embed = time.perf_counter()
    state.embedding_model = INSTRUCTOR(settings.embedding_model_id)
    if torch.cuda.is_available():
        try:
            state.embedding_model.to(state.device)
        except Exception as exc:
            # Some InstructorEmbedding backends do not expose .to(); keep the
            # CPU fallback, but report it instead of hiding the failure.
            print(
                f"Embedding model could not be moved to {state.device}; "
                f"continuing without moving it ({exc!r})"
            )
    embedding_time = time.perf_counter() - t_embed

    print(f"Loading {settings.model_id}...")
    t_model = time.perf_counter()
    state.model = _load_causal_lm(
        model_id=settings.model_id,
        dtype=state.model_dtype,
        device=state.device,
    )
    state.tokenizer = AutoTokenizer.from_pretrained(settings.model_id)
    # Without accelerate no device_map was passed, so move the model by hand.
    if not _has_accelerate():
        state.model.to(state.device)
    state.model.eval()
    model_time = time.perf_counter() - t_model

    first_param_device = str(next(state.model.parameters()).device)
    print(f"LLM loaded on : {first_param_device}")

    total_startup = time.perf_counter() - t0
    state.startup_timing = {
        "index_load_seconds": index_time,
        "chunks_load_seconds": chunks_time,
        "embedding_model_load_seconds": embedding_time,
        "llm_load_seconds": model_time,
        "total_startup_seconds": total_startup,
    }

    print("RAG API preloaded successfully")
    print(
        f"Startup timing: total={total_startup:.2f}s, index={index_time:.2f}s, "
        f"chunks={chunks_time:.2f}s, embedding={embedding_time:.2f}s, model={model_time:.2f}s"
    )
    return state.startup_timing