LightRT committed on
Commit
bb05158
·
1 Parent(s): b0980a0

Final Formatting

Browse files
Files changed (6) hide show
  1. main.py +0 -6
  2. src/embedding.py +39 -152
  3. src/fix_db.py +0 -25
  4. src/graph.py +1 -2
  5. src/main.py +79 -37
  6. src/retrieval.py +44 -97
main.py DELETED
@@ -1,6 +0,0 @@
1
- def main():
2
- print("Hello from pdf-qa-chatbot!")
3
-
4
-
5
- if __name__ == "__main__":
6
- main()
 
 
 
 
 
 
 
src/embedding.py CHANGED
@@ -1,192 +1,79 @@
1
  from src.ingestion import ingestion_and_chunking
2
-
3
  from qdrant_client import QdrantClient
4
-
5
  from qdrant_client.models import Distance, VectorParams, SparseVectorParams, PointStruct
6
-
7
  from fastembed import TextEmbedding, SparseTextEmbedding
8
-
9
  import uuid
10
-
11
  from dotenv import load_dotenv
12
-
13
  import os
14
 
15
-
16
-
17
  load_dotenv()
18
 
19
  qdrant_api_key = os.getenv("QDRANT_API_KEY")
20
-
21
  qdrant_url = os.getenv("QDRANT_URL")
22
 
23
 
24
-
25
  def upload_file(file_path: str, user_id: str, collection_name="pdf_rag"):
26
-
27
-
28
-
29
  client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
30
 
31
-
32
-
33
  dense_model = TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
34
-
35
  sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")
36
 
37
-
38
-
39
- # 1. ONLY the database creation should be inside this IF block
40
-
41
  if not client.collection_exists(collection_name):
42
-
43
  client.create_collection(
44
-
45
  collection_name=collection_name,
46
-
47
  vectors_config={
48
-
49
  "dense": VectorParams(size=384, distance=Distance.COSINE)
50
-
51
  },
52
-
53
  sparse_vectors_config={
54
-
55
  "sparse": SparseVectorParams()
56
-
57
  }
58
-
59
  )
 
 
60
 
 
 
61
 
 
 
62
 
63
- # 2. EVERYTHING ELSE MUST BE UN-INDENTED SO IT RUNS EVERY TIME
64
-
65
- try:
66
-
67
- docs = ingestion_and_chunking(file_path)
68
-
69
- texts = [doc.page_content for doc in docs]
70
-
71
-
72
-
73
- dense_vectors = list(dense_model.embed(texts))
74
-
75
- sparse_vectors = list(sparse_model.embed(texts))
76
-
77
-
78
-
79
- points = []
80
 
81
- file_id = str(uuid.uuid4())
 
 
 
 
82
 
 
83
 
84
-
85
- for i, doc in enumerate(docs):
86
-
87
- # 1. Convert numpy array to standard Python list
88
-
89
- dense_vec = dense_vectors[i].tolist()
90
-
91
-
92
-
93
- # 2. Extract indices and values from FastEmbed's custom object
94
-
95
- sparse_emb = sparse_vectors[i]
96
-
97
- sparse_vec = {
98
-
99
- "indices": sparse_emb.indices.tolist(),
100
-
101
- "values": sparse_emb.values.tolist()
102
-
103
  }
 
104
 
105
- chunk_id = str(uuid.uuid4())
106
-
107
-
108
-
109
- point = PointStruct(
110
-
111
- id=chunk_id, # Reusing the same file_id so all chunks tie back to one file
112
-
113
- vector={
114
-
115
- 'dense': dense_vec,
116
-
117
- 'sparse': sparse_vec
118
-
119
- },
120
-
121
- payload={
122
-
123
- 'user_id': user_id,
124
-
125
- 'file_id': file_id,
126
-
127
- 'text': doc.page_content,
128
-
129
- "source": doc.metadata.get("source"),
130
-
131
- "pages": doc.metadata.get("pages"),
132
-
133
- "section": doc.metadata.get("section")
134
-
135
- }
136
-
137
- )
138
-
139
- points.append(point)
140
-
141
-
142
-
143
- # (Optional but safe) Tell Qdrant to index it just in case
144
-
145
- try:
146
-
147
- client.create_payload_index(
148
-
149
- collection_name=collection_name,
150
-
151
- field_name="user_id",
152
-
153
- field_schema="keyword"
154
-
155
- )
156
-
157
- except Exception:
158
-
159
- pass
160
-
161
-
162
-
163
- # Send to database
164
-
165
- client.upsert(collection_name=collection_name, points=points)
166
-
167
-
168
-
169
- # 3. THE LOUD TERMINAL ANNOUNCEMENT
170
-
171
- print("\n" + "="*60, flush=True)
172
-
173
- print(f"✅ SUCCESS: PDF FULLY PROCESSED FOR USER {user_id}", flush=True)
174
-
175
- print("✅ YOU CAN NOW ASK QUESTIONS IN STREAMLIT!", flush=True)
176
-
177
- print("="*60 + "\n", flush=True)
178
-
179
-
180
-
181
- except Exception as e:
182
-
183
- # 4. IF IT CRASHES, SCREAM THE ERROR TO THE TERMINAL
184
-
185
- print("\n" + "!"*60, flush=True)
186
-
187
- print(f"❌ UPLOAD FAILED SILENTLY IN BACKGROUND:", flush=True)
188
-
189
- print(f"{str(e)}", flush=True)
190
 
191
- print("!"*60 + "\n", flush=True)
 
 
 
 
 
 
 
192
 
 
 
 
1
  from src.ingestion import ingestion_and_chunking
 
2
  from qdrant_client import QdrantClient
 
3
  from qdrant_client.models import Distance, VectorParams, SparseVectorParams, PointStruct
 
4
  from fastembed import TextEmbedding, SparseTextEmbedding
 
5
  import uuid
 
6
  from dotenv import load_dotenv
 
7
  import os
8
 
 
 
9
  load_dotenv()
10
 
11
  qdrant_api_key = os.getenv("QDRANT_API_KEY")
 
12
  qdrant_url = os.getenv("QDRANT_URL")
13
 
14
 
 
15
  def upload_file(file_path: str, user_id: str, collection_name="pdf_rag"):
 
 
 
16
  client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
17
 
 
 
18
  dense_model = TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
19
  sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")
20
 
 
 
 
 
21
  if not client.collection_exists(collection_name):
 
22
  client.create_collection(
 
23
  collection_name=collection_name,
 
24
  vectors_config={
 
25
  "dense": VectorParams(size=384, distance=Distance.COSINE)
 
26
  },
 
27
  sparse_vectors_config={
 
28
  "sparse": SparseVectorParams()
 
29
  }
 
30
  )
31
+ docs = ingestion_and_chunking(file_path)
32
+ texts = [doc.page_content for doc in docs]
33
 
34
+ dense_vectors = list(dense_model.embed(texts))
35
+ sparse_vectors = list(sparse_model.embed(texts))
36
 
37
+ points = []
38
+ file_id = str(uuid.uuid4())
39
 
40
+ for i, doc in enumerate(docs):
41
+ dense_vec = dense_vectors[i].tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ sparse_emb = sparse_vectors[i]
44
+ sparse_vec = {
45
+ "indices": sparse_emb.indices.tolist(),
46
+ "values": sparse_emb.values.tolist()
47
+ }
48
 
49
+ chunk_id = str(uuid.uuid4())
50
 
51
+ point = PointStruct(
52
+ id=chunk_id,
53
+ vector={
54
+ "dense": dense_vec,
55
+ "sparse": sparse_vec
56
+ },
57
+ payload={
58
+ "user_id": user_id,
59
+ "file_id": file_id,
60
+ "text": doc.page_content,
61
+ "source": doc.metadata.get("source"),
62
+ "pages": doc.metadata.get("pages"),
63
+ "section": doc.metadata.get("section")
 
 
 
 
 
 
64
  }
65
+ )
66
 
67
+ points.append(point)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ try:
70
+ client.create_payload_index(
71
+ collection_name=collection_name,
72
+ field_name="user_id",
73
+ field_schema="keyword"
74
+ )
75
+ except Exception:
76
+ pass
77
 
78
+ client.upsert(collection_name=collection_name, points=points)
79
+
src/fix_db.py DELETED
@@ -1,25 +0,0 @@
1
- import os
2
- from qdrant_client import QdrantClient
3
- from dotenv import load_dotenv
4
-
5
- load_dotenv()
6
-
7
- client = QdrantClient(
8
- url=os.getenv("QDRANT_URL"),
9
- api_key=os.getenv("QDRANT_API_KEY")
10
- )
11
-
12
- # LOOK AT YOUR retrieval.py FILE AND COPY THE EXACT COLLECTION NAME HERE
13
- COLLECTION_NAME = "pdf_rag"
14
-
15
- print(f"Attempting to build index for '{COLLECTION_NAME}'...")
16
-
17
- try:
18
- client.create_payload_index(
19
- collection_name=COLLECTION_NAME,
20
- field_name="user_id",
21
- field_schema="keyword"
22
- )
23
- print("✅ Index built successfully! Qdrant is ready.")
24
- except Exception as e:
25
- print(f"❌ FAILED: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/graph.py CHANGED
@@ -189,5 +189,4 @@ workflow.add_conditional_edges(
189
  routing,
190
  {"web_search_node": "web_search_node",
191
  "END": END})
192
- workflow.add_edge("web_search_node" , "answer_node")
193
-
 
189
  routing,
190
  {"web_search_node": "web_search_node",
191
  "END": END})
192
+ workflow.add_edge("web_search_node" , "answer_node")
 
src/main.py CHANGED
@@ -1,12 +1,11 @@
1
- from fastapi import FastAPI , HTTPException , UploadFile, File, BackgroundTasks , Form
2
- from pydantic import BaseModel , Field
3
  import os
4
  from dotenv import load_dotenv
5
  from src.graph import workflow
6
  from src.embedding import upload_file
7
  import shutil
8
  from langgraph.checkpoint.postgres import PostgresSaver
9
- from psycopg_pool import ConnectionPool
10
 
11
  load_dotenv()
12
 
@@ -16,40 +15,60 @@ app = FastAPI(
16
  version="1.0.0"
17
  )
18
 
 
19
  class ChatRequest(BaseModel):
20
- message: str = Field(..., description="The raw message string from the user.")
21
- user_id: str = Field(..., description="The unique identifier for the tenant context.")
22
- thread_id: str = Field(..., description="The unique session ID tracking the short-term chat history.")
 
 
 
 
 
 
 
 
 
 
23
 
24
- @app.post("/chat", summary="Return an answer using the RAG backend to the user query.")
 
 
 
25
  async def chat_endpoint(request: ChatRequest):
26
  try:
27
- config = {'configurable': {'thread_id': request.thread_id}}
 
 
 
 
 
28
  initial_state = {
29
  "messages": [("user", request.message)],
30
  "user_id": request.user_id
31
  }
32
-
33
- # 1. Grab the database URL
34
  db_uri = os.getenv("DATABASE_URI")
35
-
36
- # 2. Open a fresh, guaranteed-alive connection to Postgres
37
  with PostgresSaver.from_conn_string(db_uri) as checkpointer:
38
-
39
- # (Optional) Ensure tables exist
40
  checkpointer.setup()
41
-
42
- # 3. Compile the LangGraph blueprint with our fresh memory connection
43
- agent = workflow.compile(checkpointer=checkpointer)
44
-
45
- # 4. Run the AI pipeline
46
- result = agent.invoke(initial_state, config=config)
47
-
48
- # 5. Extract the AI's final answer
 
 
49
  output_messages = result.get("messages", [])
 
50
  if not output_messages:
51
- raise ValueError("No messages returned from the graph.")
52
-
 
 
53
  ai_response = output_messages[-1].content
54
 
55
  return {
@@ -57,28 +76,51 @@ async def chat_endpoint(request: ChatRequest):
57
  "thread_id": request.thread_id,
58
  "response": ai_response
59
  }
60
-
61
  except Exception as e:
62
- print(f"Backend Error: {str(e)}")
63
- raise HTTPException(status_code=500, detail=f"Agent Processing Error: {str(e)}")
64
-
 
 
 
 
 
65
  UPLOAD_DIR = "data/uploads"
 
66
  os.makedirs(UPLOAD_DIR, exist_ok=True)
67
 
68
- @app.post("/upload", summary="Upload a PDF and process its embeddings in the background")
 
 
 
 
69
  async def upload_pdf(
70
- background_tasks: BackgroundTasks,
71
  file: UploadFile = File(...),
72
- user_id : str = Form(...)
73
  ):
74
- local_file_path = os.path.join(UPLOAD_DIR, file.filename)
75
-
 
 
 
76
  with open(local_file_path, "wb") as buffer:
77
- shutil.copyfileobj(file.file, buffer)
78
-
79
- background_tasks.add_task(upload_file, local_file_path, user_id)
 
 
 
 
 
 
 
80
 
81
  return {
82
  "status": "success",
83
- "message": f"'{file.filename}' received successfully. Ingestion pipeline started in the background."
 
 
 
84
  }
 
1
+ from fastapi import FastAPI, HTTPException, UploadFile, File, BackgroundTasks, Form
2
+ from pydantic import BaseModel, Field
3
  import os
4
  from dotenv import load_dotenv
5
  from src.graph import workflow
6
  from src.embedding import upload_file
7
  import shutil
8
  from langgraph.checkpoint.postgres import PostgresSaver
 
9
 
10
  load_dotenv()
11
 
 
15
  version="1.0.0"
16
  )
17
 
18
+
19
  class ChatRequest(BaseModel):
20
+ message: str = Field(
21
+ ...,
22
+ description="The raw message string from the user."
23
+ )
24
+ user_id: str = Field(
25
+ ...,
26
+ description="The unique identifier for the tenant context."
27
+ )
28
+ thread_id: str = Field(
29
+ ...,
30
+ description="The unique session ID tracking the short-term chat history."
31
+ )
32
+
33
 
34
+ @app.post(
35
+ "/chat",
36
+ summary="Return an answer using the RAG backend to the user query."
37
+ )
38
  async def chat_endpoint(request: ChatRequest):
39
  try:
40
+ config = {
41
+ "configurable": {
42
+ "thread_id": request.thread_id
43
+ }
44
+ }
45
+
46
  initial_state = {
47
  "messages": [("user", request.message)],
48
  "user_id": request.user_id
49
  }
50
+
 
51
  db_uri = os.getenv("DATABASE_URI")
52
+
 
53
  with PostgresSaver.from_conn_string(db_uri) as checkpointer:
 
 
54
  checkpointer.setup()
55
+
56
+ agent = workflow.compile(
57
+ checkpointer=checkpointer
58
+ )
59
+
60
+ result = agent.invoke(
61
+ initial_state,
62
+ config=config
63
+ )
64
+
65
  output_messages = result.get("messages", [])
66
+
67
  if not output_messages:
68
+ raise ValueError(
69
+ "No messages returned from the graph."
70
+ )
71
+
72
  ai_response = output_messages[-1].content
73
 
74
  return {
 
76
  "thread_id": request.thread_id,
77
  "response": ai_response
78
  }
79
+
80
  except Exception as e:
81
+ print(f"Backend Error: {str(e)}")
82
+
83
+ raise HTTPException(
84
+ status_code=500,
85
+ detail=f"Agent Processing Error: {str(e)}"
86
+ )
87
+
88
+
89
  UPLOAD_DIR = "data/uploads"
90
+
91
  os.makedirs(UPLOAD_DIR, exist_ok=True)
92
 
93
+
94
+ @app.post(
95
+ "/upload",
96
+ summary="Upload a PDF and process its embeddings in the background"
97
+ )
98
  async def upload_pdf(
99
+ background_tasks: BackgroundTasks,
100
  file: UploadFile = File(...),
101
+ user_id: str = Form(...)
102
  ):
103
+ local_file_path = os.path.join(
104
+ UPLOAD_DIR,
105
+ file.filename
106
+ )
107
+
108
  with open(local_file_path, "wb") as buffer:
109
+ shutil.copyfileobj(
110
+ file.file,
111
+ buffer
112
+ )
113
+
114
+ background_tasks.add_task(
115
+ upload_file,
116
+ local_file_path,
117
+ user_id
118
+ )
119
 
120
  return {
121
  "status": "success",
122
+ "message": (
123
+ f"'{file.filename}' received successfully. "
124
+ "Ingestion pipeline started in the background."
125
+ )
126
  }
src/retrieval.py CHANGED
@@ -1,138 +1,85 @@
1
  import os
2
-
3
  from dotenv import load_dotenv
4
-
5
  from qdrant_client import QdrantClient
6
-
7
  from qdrant_client import models
8
-
9
  from fastembed import TextEmbedding, SparseTextEmbedding
10
-
11
  from fastembed.rerank.cross_encoder import TextCrossEncoder
12
 
13
-
14
-
15
  load_dotenv()
16
 
17
-
18
-
19
  qdrant_api_key = os.getenv("QDRANT_API_KEY")
20
-
21
  qdrant_url = os.getenv("QDRANT_URL")
22
 
23
 
24
-
25
- class Retriever() :
26
-
27
- def __init__(self , collection_name = 'pdf_rag') :
28
-
29
  self.collection_name = collection_name
30
-
31
- self.client = QdrantClient(url=qdrant_url , api_key=qdrant_api_key)
32
-
33
-
34
 
35
  self.dense_model = TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
36
-
37
  self.sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")
38
 
39
-
40
-
41
  self.reranker = TextCrossEncoder(model_name="Xenova/ms-marco-MiniLM-L-6-v2")
42
 
43
-
44
-
45
-
46
-
47
- def retrieve(self , query : str , user_id : str) :
48
-
49
-
50
-
51
  dense_query_vector = list(self.dense_model.embed([query]))[0]
52
 
53
-
54
-
55
  sparse_query = list(self.sparse_model.embed([query]))[0]
56
-
57
- sparse_query_vector = models.SparseVector(indices=sparse_query.indices,
58
-
59
- values=sparse_query.values)
60
-
61
-
62
-
63
- user_filter = models.Filter(must=[models.FieldCondition(key="user_id" , match=models.MatchValue(value=user_id))])
64
-
65
-
66
-
67
- results = self.client.query_points(collection_name=self.collection_name,
68
-
69
- prefetch=[models.Prefetch(
70
-
71
- query=dense_query_vector,
72
-
73
- limit=20,
74
-
75
- using='dense',
76
-
77
- filter=user_filter
78
-
79
- ),
80
-
81
- models.Prefetch(
82
-
83
- query=sparse_query_vector,
84
-
85
- using='sparse',
86
-
87
- limit=20,
88
-
89
- filter=user_filter
90
-
91
- )],
92
-
93
- query=models.FusionQuery(fusion=models.Fusion.RRF),
94
-
95
- limit=20)
96
-
97
-
98
-
99
- texts = [point.payload.get('text' , '') for point in results.points]
100
-
101
-
102
 
103
  rerank_scores = list(self.reranker.rerank(query, texts))
104
 
105
-
106
-
107
  reranked_results = []
108
 
109
  for point, score in zip(results.points, rerank_scores):
110
-
111
  reranked_results.append({
112
-
113
  "text": point.payload.get("text"),
114
-
115
  "source": point.payload.get("source"),
116
-
117
  "pages": point.payload.get("pages"),
118
-
119
  "section": point.payload.get("section"),
120
-
121
  "original_qdrant_score": point.score,
122
-
123
  "rerank_score": float(score)
124
-
125
  })
126
 
127
-
128
-
129
- reranked_results.sort(key=lambda x: x["rerank_score"], reverse=True)
130
-
131
-
132
 
133
  final_top_results = reranked_results[:5]
134
 
135
-
136
-
137
- return final_top_results
138
-
 
1
  import os
 
2
  from dotenv import load_dotenv
 
3
  from qdrant_client import QdrantClient
 
4
  from qdrant_client import models
 
5
  from fastembed import TextEmbedding, SparseTextEmbedding
 
6
  from fastembed.rerank.cross_encoder import TextCrossEncoder
7
 
 
 
8
  load_dotenv()
9
 
 
 
10
  qdrant_api_key = os.getenv("QDRANT_API_KEY")
 
11
  qdrant_url = os.getenv("QDRANT_URL")
12
 
13
 
14
+ class Retriever:
15
+ def __init__(self, collection_name="pdf_rag"):
 
 
 
16
  self.collection_name = collection_name
17
+ self.client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
 
 
 
18
 
19
  self.dense_model = TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
20
  self.sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")
21
 
 
 
22
  self.reranker = TextCrossEncoder(model_name="Xenova/ms-marco-MiniLM-L-6-v2")
23
 
24
+ def retrieve(self, query: str, user_id: str):
 
 
 
 
 
 
 
25
  dense_query_vector = list(self.dense_model.embed([query]))[0]
26
 
 
 
27
  sparse_query = list(self.sparse_model.embed([query]))[0]
28
+ sparse_query_vector = models.SparseVector(
29
+ indices=sparse_query.indices,
30
+ values=sparse_query.values
31
+ )
32
+
33
+ user_filter = models.Filter(
34
+ must=[
35
+ models.FieldCondition(
36
+ key="user_id",
37
+ match=models.MatchValue(value=user_id)
38
+ )
39
+ ]
40
+ )
41
+
42
+ results = self.client.query_points(
43
+ collection_name=self.collection_name,
44
+ prefetch=[
45
+ models.Prefetch(
46
+ query=dense_query_vector,
47
+ limit=20,
48
+ using="dense",
49
+ filter=user_filter
50
+ ),
51
+ models.Prefetch(
52
+ query=sparse_query_vector,
53
+ using="sparse",
54
+ limit=20,
55
+ filter=user_filter
56
+ )
57
+ ],
58
+ query=models.FusionQuery(fusion=models.Fusion.RRF),
59
+ limit=20
60
+ )
61
+
62
+ texts = [
63
+ point.payload.get("text", "")
64
+ for point in results.points
65
+ ]
 
 
 
 
 
 
 
 
66
 
67
  rerank_scores = list(self.reranker.rerank(query, texts))
68
 
 
 
69
  reranked_results = []
70
 
71
  for point, score in zip(results.points, rerank_scores):
 
72
  reranked_results.append({
 
73
  "text": point.payload.get("text"),
 
74
  "source": point.payload.get("source"),
 
75
  "pages": point.payload.get("pages"),
 
76
  "section": point.payload.get("section"),
 
77
  "original_qdrant_score": point.score,
 
78
  "rerank_score": float(score)
 
79
  })
80
 
81
+ reranked_results.sort(key=lambda x: x["rerank_score"],reverse=True)
 
 
 
 
82
 
83
  final_top_results = reranked_results[:5]
84
 
85
+ return final_top_results