Pujan-Dev committed on
Commit
128b0a8
·
verified ·
1 Parent(s): 12e4d25

Upload 8 files

Browse files
Files changed (6) hide show
  1. .gitignore +1 -0
  2. Dockerfile +39 -0
  3. config.py +93 -0
  4. main.py +37 -0
  5. rag_service.py +248 -0
  6. schemas.py +26 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
Dockerfile ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11-slim

# Keep Python output unbuffered and avoid .pyc files in containers.
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
# Skip pip's download cache to keep image layers small.
ENV PIP_NO_CACHE_DIR=1

# Optional Hugging Face cache location inside the container.
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favour of HF_HOME; both are set here for compatibility — confirm
# against the installed transformers version.
ENV HF_HOME=/app/.cache/huggingface
ENV TRANSFORMERS_CACHE=/app/.cache/huggingface

WORKDIR /app

# System libs often needed by ML wheels/runtime.
# git for VCS-pinned installs, build-essential for source builds; the apt
# list cleanup in the same RUN keeps this layer small.
RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies used by Fastapi/main.py.
RUN pip install --upgrade pip && pip install \
    fastapi \
    "uvicorn[standard]" \
    numpy \
    faiss-cpu \
    torch \
    transformers \
    sentencepiece \
    InstructorEmbedding \
    langchain-core

# Copy the whole repo so Fastapi app can resolve vector_db.index/chunks.pkl
# from /app, /app/Fastapi, or /app/RAG_pipeline.
COPY . /app

EXPOSE 8000

# Run FastAPI app.
CMD ["uvicorn", "Fastapi.main:app", "--host", "0.0.0.0", "--port", "8000"]
config.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from pathlib import Path
3
+ import os
4
+
5
+
6
+ def _load_dotenv(dotenv_path: Path) -> None:
7
+ if not dotenv_path.exists():
8
+ return
9
+
10
+ for raw_line in dotenv_path.read_text(encoding="utf-8").splitlines():
11
+ line = raw_line.strip()
12
+ if not line or line.startswith("#") or "=" not in line:
13
+ continue
14
+
15
+ key, value = line.split("=", 1)
16
+ key = key.strip()
17
+ value = value.strip().strip('"').strip("'")
18
+ os.environ.setdefault(key, value)
19
+
20
+
21
+ def _get_env(name: str, default: str, aliases: tuple[str, ...] = ()) -> str:
22
+ for key in (name, *aliases):
23
+ value = os.getenv(key)
24
+ if value is not None and value != "":
25
+ return value
26
+ return default
27
+
28
+
29
+ def _to_int(value: str, default: int) -> int:
30
+ try:
31
+ return int(value)
32
+ except (TypeError, ValueError):
33
+ return default
34
+
35
+
36
+ def _to_float(value: str, default: float) -> float:
37
+ try:
38
+ return float(value)
39
+ except (TypeError, ValueError):
40
+ return default
41
+
42
+
43
# Resolve the directory containing this module and seed os.environ from an
# optional .env file beside it (real environment variables take precedence).
_BASE_DIR = Path(__file__).resolve().parent
_load_dotenv(_BASE_DIR / ".env")
45
+
46
+
47
@dataclass
class Settings:
    """Runtime configuration resolved from environment variables (and .env)."""

    # Title shown in the FastAPI docs UI.
    app_title: str = _get_env("APP_TITLE", "RAG API")
    # Causal-LM checkpoint id; MODEL_NAME is accepted as a legacy alias.
    model_id: str = _get_env("MODEL_ID", "Qwen/Qwen2.5-1.5B-Instruct", aliases=("MODEL_NAME",))
    # INSTRUCTOR embedding checkpoint used for retrieval.
    embedding_model_id: str = _get_env(
        "EMBEDDING_MODEL_ID",
        "hkunlp/instructor-base",
        aliases=("EMBEDDING_MODEL",),
    )

    # Data file names, resolved against data_search_roots below.
    models_dir: str = _get_env("MODELS_DIR", "Models")
    vector_db_file: str = _get_env("VECTOR_DB_FILE", "vector_db.index", aliases=("VECTOR_STORE_PATH",))
    chunks_file: str = _get_env("CHUNKS_FILE", "chunks.pkl")

    # Instruction prepended to each query for the INSTRUCTOR embedding model.
    retrieval_instruction: str = _get_env(
        "RETRIEVAL_INSTRUCTION",
        "Represent the question for retrieving relevant documents",
    )

    # Prompt-length budget and generation/sampling parameters.
    max_context_tokens: int = _to_int(_get_env("MAX_CONTEXT_TOKENS", "3072"), 3072)
    max_new_tokens: int = _to_int(_get_env("MAX_NEW_TOKENS", "500"), 500)
    temperature: float = _to_float(_get_env("TEMPERATURE", "0.3"), 0.3)
    repetition_penalty: float = _to_float(_get_env("REPETITION_PENALTY", "1.3"), 1.3)

    # Default and allowed bounds for the per-request top-k parameter.
    default_top_k: int = _to_int(_get_env("DEFAULT_TOP_K", "3"), 3)
    min_top_k: int = _to_int(_get_env("MIN_TOP_K", "1"), 1)
    max_top_k: int = _to_int(_get_env("MAX_TOP_K", "10"), 10)

    # Bind address for the API server.
    host: str = _get_env("HOST", "0.0.0.0", aliases=("API_HOST",))
    port: int = _to_int(_get_env("PORT", "8000", aliases=("API_PORT",)), 8000)

    # Directory containing this config module (default_factory so the
    # mutable Path is not shared as a plain class default).
    base_dir: Path = field(default_factory=lambda: _BASE_DIR)

    @property
    def data_search_roots(self) -> list[Path]:
        """Candidate directories, in priority order, for locating data files."""
        models_path = Path(self.models_dir)
        return [
            self.base_dir / models_path,
            self.base_dir,
            self.base_dir.parent / models_path,
            self.base_dir.parent / "RAG_pipeline" / models_path,
            self.base_dir.parent / "RAG_pipeline",
            self.base_dir.parent,
        ]
91
+
92
+
93
# Module-level singleton imported by the rest of the application.
settings = Settings()
main.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from contextlib import asynccontextmanager
3
+
4
+ from config import settings
5
+ from rag_service import preload, rag_query, state
6
+ from schemas import QueryRequest, QueryResponse
7
+
8
+
9
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Load models and data once at startup; nothing to tear down on shutdown."""
    preload()
    yield
13
+
14
+
15
# FastAPI application; `lifespan` preloads models before requests are served.
app = FastAPI(title=settings.app_title, lifespan=lifespan)
16
+
17
+
18
+ @app.get("/")
19
+ def root():
20
+ model_runtime_device = None
21
+ if state.model is not None:
22
+ model_runtime_device = str(next(state.model.parameters()).device)
23
+ return {
24
+ "message": "RAG API is running",
25
+ "device": state.device,
26
+ "model_runtime_device": model_runtime_device,
27
+ "model_dtype": str(state.model_dtype),
28
+ "startup_timing": state.startup_timing,
29
+ }
30
+
31
+
32
+ @app.post("/query", response_model=QueryResponse)
33
+ def query(payload: QueryRequest):
34
+ if state.index is None or state.embedding_model is None or state.model is None:
35
+ raise HTTPException(status_code=503, detail="Model is not loaded yet")
36
+ result = rag_query(payload.question, k=payload.k)
37
+ return QueryResponse(**result)
rag_service.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import pickle
3
+ import time
4
+
5
+ import faiss
6
+ import numpy as np
7
+ import torch
8
+ from InstructorEmbedding import INSTRUCTOR
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+
11
+ from config import settings
12
+
13
+
14
+ class _CompatDocument:
15
+ """Fallback placeholder for pickled langchain Document objects."""
16
+
17
+ pass
18
+
19
+
20
+ class _CompatUnpickler(pickle.Unpickler):
21
+ """Map langchain document class to a lightweight local placeholder."""
22
+
23
+ def find_class(self, module, name):
24
+ if module == "langchain_core.documents.base" and name == "Document":
25
+ return _CompatDocument
26
+ return super().find_class(module, name)
27
+
28
+
29
+ def _load_chunks(path: Path):
30
+ """Load chunks.pkl with normal pickle, then fallback if langchain_core is absent."""
31
+ with path.open("rb") as f:
32
+ try:
33
+ return pickle.load(f)
34
+ except ModuleNotFoundError as e:
35
+ if e.name != "langchain_core":
36
+ raise
37
+
38
+ with path.open("rb") as f:
39
+ return _CompatUnpickler(f).load()
40
+
41
+
42
+ def _chunk_payload(chunk):
43
+ """Return the serialized payload for both real and fallback document objects."""
44
+ if hasattr(chunk, "page_content") and hasattr(chunk, "metadata"):
45
+ return {
46
+ "page_content": chunk.page_content,
47
+ "metadata": chunk.metadata,
48
+ }
49
+
50
+ raw = getattr(chunk, "__dict__", {})
51
+ nested = raw.get("__dict__", raw)
52
+ if isinstance(nested, dict):
53
+ return nested
54
+ return {}
55
+
56
+
57
def _chunk_page_content(chunk):
    """Return the chunk's text body, or "" when absent."""
    return _chunk_payload(chunk).get("page_content", "")


def _chunk_metadata(chunk):
    """Return the chunk's metadata dict, or {} when absent."""
    return _chunk_payload(chunk).get("metadata", {})
63
+
64
+
65
def find_data_file(filename: str) -> Path:
    """Locate a data file by absolute path or by scanning the search roots.

    Args:
        filename: An absolute path, or a bare name resolved against
            ``settings.data_search_roots`` in priority order.

    Returns:
        Path of the first existing candidate.

    Raises:
        FileNotFoundError: If no candidate exists.
    """
    explicit = Path(filename)
    if explicit.is_absolute() and explicit.exists():
        return explicit

    for root in settings.data_search_roots:
        candidate = root / filename
        if candidate.exists():
            return candidate
    # Bug fix: the original message was an f-string with no placeholder and
    # said "(unknown)"; include the actual filename so the error is actionable.
    raise FileNotFoundError(f"Could not find {filename!r} in expected locations")
75
+
76
+
77
class AppState:
    """Mutable container for everything loaded at startup."""

    def __init__(self):
        # Prefer the first CUDA device when available; otherwise run on CPU.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model_id = settings.model_id
        # fp16 on GPU (possibly upgraded to bf16 in preload()), fp32 on CPU.
        self.model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        # All of these are populated by preload(); None means "not loaded yet".
        self.index = None            # FAISS vector index
        self.chunks = None           # deserialized document chunks
        self.embedding_model = None  # INSTRUCTOR embedding model
        self.model = None            # causal LM
        self.tokenizer = None        # tokenizer matching self.model
        self.startup_timing = {}     # seconds spent in each preload stage
89
+
90
+
91
# Shared module-level state used by the FastAPI handlers.
state = AppState()
92
+
93
+
94
def retrieve_chunks(query: str, k: int) -> list:
    """Return the k chunks whose embeddings are nearest to the query."""
    # INSTRUCTOR expects [instruction, text] pairs for each input.
    embedding = state.embedding_model.encode([[settings.retrieval_instruction, query]])[0]
    vector = np.array([embedding]).astype("float32")
    _scores, hit_ids = state.index.search(vector, k)
    return [state.chunks[chunk_id] for chunk_id in hit_ids[0]]
99
+
100
+
101
def generate_answer(question: str, retrieved_chunks: list) -> str:
    """Generate an answer to `question` grounded in `retrieved_chunks`.

    Builds a chat prompt containing the numbered source texts, samples from
    the preloaded LLM, and returns only the newly generated text.
    """
    # Number each source so the model can synthesize across all of them.
    context = ""
    for i, chunk in enumerate(retrieved_chunks):
        context += f"Source {i + 1}:\n{_chunk_page_content(chunk)}\n\n"

    messages = [
        {
            "role": "system",
            "content": (
                "You are a helpful assistant that answers questions using ONLY the provided sources. "
                "Synthesize information from ALL sources given. "
                "Give a complete and coherent answer. "
                "Do not cut off mid sentence. "
                "If the sources do not contain enough information say so clearly."
            ),
        },
        {
            "role": "user",
            "content": (
                f"Question: {question}\n\n"
                f"{context}"
                "Based on ALL the sources above provide a complete answer to the question."
            ),
        },
    ]

    # Render the messages with the model's own chat template.
    text = state.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Truncate the prompt to the configured context budget, then move the
    # tensors to the model's device.
    inputs = state.tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=settings.max_context_tokens,
    ).to(state.device)

    with torch.no_grad():
        output = state.model.generate(
            **inputs,
            max_new_tokens=settings.max_new_tokens,
            temperature=settings.temperature,
            do_sample=True,
            # Many causal LMs ship without a pad token; fall back to EOS.
            pad_token_id=state.tokenizer.eos_token_id,
            repetition_penalty=settings.repetition_penalty,
        )

    # Drop the prompt tokens; decode only what the model generated.
    generated_tokens = output[0][inputs["input_ids"].shape[1] :]
    return state.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
152
+
153
+
154
def rag_query(question: str, k: int) -> dict:
    """Run retrieve-then-generate for `question` and report stage timings."""
    overall_start = time.perf_counter()

    stage_start = time.perf_counter()
    retrieved = retrieve_chunks(question, k=k)
    retrieval_seconds = time.perf_counter() - stage_start

    stage_start = time.perf_counter()
    answer = generate_answer(question, retrieved)
    generation_seconds = time.perf_counter() - stage_start
    total_seconds = time.perf_counter() - overall_start

    # One source URL per retrieved chunk ("" when the chunk carries none).
    return {
        "question": question,
        "answer": answer,
        "sources": [_chunk_metadata(chunk).get("url", "") for chunk in retrieved],
        "timing": {
            "retrieval_seconds": retrieval_seconds,
            "generation_seconds": generation_seconds,
            "total_seconds": total_seconds,
        },
    }
177
+
178
+
179
def preload() -> dict:
    """Load the index, chunks, embedding model, and LLM into `state`.

    Returns:
        The per-stage startup timing dict (also stored on
        ``state.startup_timing``).
    """
    t0 = time.perf_counter()

    print(f"Using device : {state.device}")
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        print(f"CUDA available : True ({gpu_name})")
        # Prefer bf16 where supported: wider dynamic range than fp16.
        if torch.cuda.is_bf16_supported():
            state.model_dtype = torch.bfloat16
        else:
            state.model_dtype = torch.float16
        print(f"Model dtype : {state.model_dtype}")
    else:
        print("CUDA available : False")
        state.model_dtype = torch.float32

    # Stage 1: FAISS index.
    print("Loading vector DB...")
    t_index = time.perf_counter()
    index_path = find_data_file(settings.vector_db_file)
    state.index = faiss.read_index(str(index_path))
    index_time = time.perf_counter() - t_index
    print(f"Index loaded : {state.index.ntotal} vectors")

    # Stage 2: pickled document chunks (must align with the index rows).
    print("Loading chunks...")
    t_chunks = time.perf_counter()
    chunks_path = find_data_file(settings.chunks_file)
    state.chunks = _load_chunks(chunks_path)
    chunks_time = time.perf_counter() - t_chunks
    print(f"Chunks loaded : {len(state.chunks)}")

    # Stage 3: embedding model for query-time retrieval.
    print("Loading embedding model...")
    t_embed = time.perf_counter()
    state.embedding_model = INSTRUCTOR(settings.embedding_model_id)
    if torch.cuda.is_available():
        try:
            state.embedding_model.to(state.device)
        except Exception:
            # Some InstructorEmbedding backends do not expose .to(); keep CPU fallback.
            pass
    embedding_time = time.perf_counter() - t_embed

    # Stage 4: the causal LM and its tokenizer.
    print(f"Loading {settings.model_id}...")
    t_model = time.perf_counter()
    state.model = AutoModelForCausalLM.from_pretrained(
        settings.model_id,
        torch_dtype=state.model_dtype,
        device_map={"": state.device},
    )
    state.tokenizer = AutoTokenizer.from_pretrained(settings.model_id)
    state.model.eval()
    model_time = time.perf_counter() - t_model

    first_param_device = str(next(state.model.parameters()).device)
    print(f"LLM loaded on : {first_param_device}")

    total_startup = time.perf_counter() - t0
    state.startup_timing = {
        "index_load_seconds": index_time,
        "chunks_load_seconds": chunks_time,
        "embedding_model_load_seconds": embedding_time,
        "llm_load_seconds": model_time,
        "total_startup_seconds": total_startup,
    }

    print("RAG API preloaded successfully")
    print(
        f"Startup timing: total={total_startup:.2f}s, index={index_time:.2f}s, "
        f"chunks={chunks_time:.2f}s, embedding={embedding_time:.2f}s, model={model_time:.2f}s"
    )
    return state.startup_timing
schemas.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+
3
+ from config import settings
4
+
5
+
6
class QueryRequest(BaseModel):
    """Request body for POST /query."""

    question: str = Field(..., min_length=1, description="User question")
    # Bounds and default come from configuration so deployments can tune
    # retrieval depth without code changes.
    k: int = Field(
        default=settings.default_top_k,
        ge=settings.min_top_k,
        le=settings.max_top_k,
        description="Top-k chunks to retrieve",
    )
14
+
15
+
16
class TimingPayload(BaseModel):
    """Per-stage timing (in seconds) for a single query."""

    retrieval_seconds: float
    generation_seconds: float
    total_seconds: float
20
+
21
+
22
class QueryResponse(BaseModel):
    """Response body for POST /query."""

    question: str
    answer: str
    sources: list[str]  # one URL (possibly "") per retrieved chunk
    timing: TimingPayload