LightRT commited on
Commit
77d7fca
·
1 Parent(s): 86cca3c

Final Changes

Browse files
app.py CHANGED
@@ -33,7 +33,7 @@ with st.sidebar:
33
  # Send the POST request to your local FastAPI server
34
  try:
35
  response = requests.post(
36
- "http://127.0.0.1:8000",
37
  files=files,
38
  data=payload_data
39
  )
@@ -74,7 +74,7 @@ if prompt := st.chat_input("Ask a question about your documents..."):
74
 
75
  try:
76
  # Send the question to your LangGraph backend
77
- chat_response = requests.post("http://127.0.0.1:8000", json=payload)
78
 
79
  if chat_response.status_code == 200:
80
  # Extract the answer from the JSON response
 
33
  # Send the POST request to your local FastAPI server
34
  try:
35
  response = requests.post(
36
+ "http://127.0.0.1:8000/upload",
37
  files=files,
38
  data=payload_data
39
  )
 
74
 
75
  try:
76
  # Send the question to your LangGraph backend
77
+ chat_response = requests.post("http://127.0.0.1:8000/chat", json=payload)
78
 
79
  if chat_response.status_code == 200:
80
  # Extract the answer from the JSON response
src/__pycache__/embedding.cpython-312.pyc CHANGED
Binary files a/src/__pycache__/embedding.cpython-312.pyc and b/src/__pycache__/embedding.cpython-312.pyc differ
 
src/__pycache__/ingestion.cpython-312.pyc CHANGED
Binary files a/src/__pycache__/ingestion.cpython-312.pyc and b/src/__pycache__/ingestion.cpython-312.pyc differ
 
src/__pycache__/retrieval.cpython-312.pyc CHANGED
Binary files a/src/__pycache__/retrieval.cpython-312.pyc and b/src/__pycache__/retrieval.cpython-312.pyc differ
 
src/embedding.py CHANGED
@@ -1,92 +1,192 @@
1
  from src.ingestion import ingestion_and_chunking
 
2
  from qdrant_client import QdrantClient
 
3
  from qdrant_client.models import Distance, VectorParams, SparseVectorParams, PointStruct
4
- from fastembed import SparseTextEmbedding
 
 
5
  import uuid
 
6
  from dotenv import load_dotenv
 
7
  import os
8
- from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
 
9
 
10
  load_dotenv()
 
11
  qdrant_api_key = os.getenv("QDRANT_API_KEY")
 
12
  qdrant_url = os.getenv("QDRANT_URL")
13
- hf_token = os.getenv("HF_TOKEN")
14
 
15
- def upload_file(file_path: str, user_id: str, collection_name="pdf_rag_chat"):
 
 
 
 
16
 
17
  client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
18
 
19
- dense_model = HuggingFaceInferenceAPIEmbeddings(
20
- api_key=hf_token,
21
- model_name="sentence-transformers/all-MiniLM-L6-v2")
 
22
  sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")
23
 
 
 
24
  # 1. ONLY the database creation should be inside this IF block
 
25
  if not client.collection_exists(collection_name):
 
26
  client.create_collection(
 
27
  collection_name=collection_name,
 
28
  vectors_config={
 
29
  "dense": VectorParams(size=384, distance=Distance.COSINE)
 
30
  },
 
31
  sparse_vectors_config={
 
32
  "sparse": SparseVectorParams()
 
33
  }
 
34
  )
35
 
 
 
36
  # 2. EVERYTHING ELSE MUST BE UN-INDENTED SO IT RUNS EVERY TIME
 
37
  try:
 
38
  docs = ingestion_and_chunking(file_path)
 
39
  texts = [doc.page_content for doc in docs]
40
 
41
- dense_vectors = dense_model.embed_documents(texts)
 
 
 
42
  sparse_vectors = list(sparse_model.embed(texts))
43
 
 
 
44
  points = []
 
45
  file_id = str(uuid.uuid4())
46
 
 
 
47
  for i, doc in enumerate(docs):
 
48
  # 1. Convert numpy array to standard Python list
49
- dense_vec = dense_vectors[i]
50
-
 
 
 
51
  # 2. Extract indices and values from FastEmbed's custom object
 
52
  sparse_emb = sparse_vectors[i]
 
53
  sparse_vec = {
 
54
  "indices": sparse_emb.indices.tolist(),
 
55
  "values": sparse_emb.values.tolist()
 
56
  }
 
57
  chunk_id = str(uuid.uuid4())
58
 
 
 
59
  point = PointStruct(
 
60
  id=chunk_id, # Reusing the same file_id so all chunks tie back to one file
 
61
  vector={
 
62
  'dense': dense_vec,
 
63
  'sparse': sparse_vec
 
64
  },
 
65
  payload={
 
66
  'user_id': user_id,
 
67
  'file_id': file_id,
 
68
  'text': doc.page_content,
 
69
  "source": doc.metadata.get("source"),
 
70
  "pages": doc.metadata.get("pages"),
 
71
  "section": doc.metadata.get("section")
 
72
  }
 
73
  )
 
74
  points.append(point)
75
 
 
 
76
  # (Optional but safe) Tell Qdrant to index it just in case
 
77
  try:
 
78
  client.create_payload_index(
79
- collection_name=collection_name,
 
 
80
  field_name="user_id",
 
81
  field_schema="keyword"
 
82
  )
 
83
  except Exception:
 
84
  pass
85
 
 
 
86
  # Send to database
 
87
  client.upsert(collection_name=collection_name, points=points)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  except Exception as e:
 
 
 
89
  print("\n" + "!"*60, flush=True)
 
90
  print(f"❌ UPLOAD FAILED SILENTLY IN BACKGROUND:", flush=True)
 
91
  print(f"{str(e)}", flush=True)
92
- print("!"*60 + "\n", flush=True)
 
 
 
1
  from src.ingestion import ingestion_and_chunking
2
+
3
  from qdrant_client import QdrantClient
4
+
5
  from qdrant_client.models import Distance, VectorParams, SparseVectorParams, PointStruct
6
+
7
+ from fastembed import TextEmbedding, SparseTextEmbedding
8
+
9
  import uuid
10
+
11
  from dotenv import load_dotenv
12
+
13
  import os
14
+
15
+
16
 
17
  load_dotenv()
18
+
19
  qdrant_api_key = os.getenv("QDRANT_API_KEY")
20
+
21
  qdrant_url = os.getenv("QDRANT_URL")
 
22
 
23
+
24
+
25
+ def upload_file(file_path: str, user_id: str, collection_name="pdf_rag"):
26
+
27
+
28
 
29
  client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
30
 
31
+
32
+
33
+ dense_model = TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
34
+
35
  sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")
36
 
37
+
38
+
39
  # 1. ONLY the database creation should be inside this IF block
40
+
41
  if not client.collection_exists(collection_name):
42
+
43
  client.create_collection(
44
+
45
  collection_name=collection_name,
46
+
47
  vectors_config={
48
+
49
  "dense": VectorParams(size=384, distance=Distance.COSINE)
50
+
51
  },
52
+
53
  sparse_vectors_config={
54
+
55
  "sparse": SparseVectorParams()
56
+
57
  }
58
+
59
  )
60
 
61
+
62
+
63
  # 2. EVERYTHING ELSE MUST BE UN-INDENTED SO IT RUNS EVERY TIME
64
+
65
  try:
66
+
67
  docs = ingestion_and_chunking(file_path)
68
+
69
  texts = [doc.page_content for doc in docs]
70
 
71
+
72
+
73
+ dense_vectors = list(dense_model.embed(texts))
74
+
75
  sparse_vectors = list(sparse_model.embed(texts))
76
 
77
+
78
+
79
  points = []
80
+
81
  file_id = str(uuid.uuid4())
82
 
83
+
84
+
85
  for i, doc in enumerate(docs):
86
+
87
  # 1. Convert numpy array to standard Python list
88
+
89
+ dense_vec = dense_vectors[i].tolist()
90
+
91
+
92
+
93
  # 2. Extract indices and values from FastEmbed's custom object
94
+
95
  sparse_emb = sparse_vectors[i]
96
+
97
  sparse_vec = {
98
+
99
  "indices": sparse_emb.indices.tolist(),
100
+
101
  "values": sparse_emb.values.tolist()
102
+
103
  }
104
+
105
  chunk_id = str(uuid.uuid4())
106
 
107
+
108
+
109
  point = PointStruct(
110
+
111
  id=chunk_id, # Reusing the same file_id so all chunks tie back to one file
112
+
113
  vector={
114
+
115
  'dense': dense_vec,
116
+
117
  'sparse': sparse_vec
118
+
119
  },
120
+
121
  payload={
122
+
123
  'user_id': user_id,
124
+
125
  'file_id': file_id,
126
+
127
  'text': doc.page_content,
128
+
129
  "source": doc.metadata.get("source"),
130
+
131
  "pages": doc.metadata.get("pages"),
132
+
133
  "section": doc.metadata.get("section")
134
+
135
  }
136
+
137
  )
138
+
139
  points.append(point)
140
 
141
+
142
+
143
  # (Optional but safe) Tell Qdrant to index it just in case
144
+
145
  try:
146
+
147
  client.create_payload_index(
148
+
149
+ collection_name=collection_name,
150
+
151
  field_name="user_id",
152
+
153
  field_schema="keyword"
154
+
155
  )
156
+
157
  except Exception:
158
+
159
  pass
160
 
161
+
162
+
163
  # Send to database
164
+
165
  client.upsert(collection_name=collection_name, points=points)
166
+
167
+
168
+
169
+ # 3. THE LOUD TERMINAL ANNOUNCEMENT
170
+
171
+ print("\n" + "="*60, flush=True)
172
+
173
+ print(f"✅ SUCCESS: PDF FULLY PROCESSED FOR USER {user_id}", flush=True)
174
+
175
+ print("✅ YOU CAN NOW ASK QUESTIONS IN STREAMLIT!", flush=True)
176
+
177
+ print("="*60 + "\n", flush=True)
178
+
179
+
180
+
181
  except Exception as e:
182
+
183
+ # 4. IF IT CRASHES, SCREAM THE ERROR TO THE TERMINAL
184
+
185
  print("\n" + "!"*60, flush=True)
186
+
187
  print(f"❌ UPLOAD FAILED SILENTLY IN BACKGROUND:", flush=True)
188
+
189
  print(f"{str(e)}", flush=True)
190
+
191
+ print("!"*60 + "\n", flush=True)
192
+
src/ingestion.py CHANGED
@@ -2,8 +2,6 @@ from docling.document_converter import DocumentConverter
2
  from docling.chunking import HybridChunker
3
  from transformers import AutoTokenizer
4
  from langchain_core.documents import Document
5
- from docling_core.transforms.chunker.tokenizer.openai import OpenAITokenizer
6
-
7
 
8
  def ingestion_and_chunking(file_path : str) :
9
 
 
2
  from docling.chunking import HybridChunker
3
  from transformers import AutoTokenizer
4
  from langchain_core.documents import Document
 
 
5
 
6
  def ingestion_and_chunking(file_path : str) :
7
 
src/retrieval.py CHANGED
@@ -1,95 +1,138 @@
1
  import os
2
- import requests
3
  from dotenv import load_dotenv
 
4
  from qdrant_client import QdrantClient
 
5
  from qdrant_client import models
6
- from fastembed import SparseTextEmbedding
7
- from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
 
 
 
 
8
 
9
  load_dotenv()
10
 
 
 
11
  qdrant_api_key = os.getenv("QDRANT_API_KEY")
 
12
  qdrant_url = os.getenv("QDRANT_URL")
13
- hf_token = os.getenv("HF_TOKEN")
 
14
 
15
  class Retriever() :
16
- def __init__(self , collection_name = 'pdf_rag_v3') :
 
 
17
  self.collection_name = collection_name
 
18
  self.client = QdrantClient(url=qdrant_url , api_key=qdrant_api_key)
19
 
20
- # 🚨 THE FIX: Do NOT load models here. Let the server boot fast and light.
21
- self.dense_model = None
22
- self.sparse_model = None
23
-
24
- def cloud_rerank(self, query, texts):
25
- API_URL = "https://api-inference.huggingface.co/models/cross-encoder/ms-marco-MiniLM-L-6-v2"
26
- headers = {"Authorization": f"Bearer {hf_token}"}
27
- payload = {
28
- "inputs": {
29
- "source_sentence": query,
30
- "sentences": texts
31
- }
32
- }
33
- try:
34
- response = requests.post(API_URL, headers=headers, json=payload)
35
- if response.status_code == 200:
36
- return response.json()
37
- except Exception as e:
38
- print(f"Cloud reranker failed: {e}")
39
- pass
40
-
41
- return [0.0] * len(texts)
42
 
43
 
44
  def retrieve(self , query : str , user_id : str) :
45
- # 🚨 THE FIX: Lazy Load. Only turn the models on the very first time someone asks a question!
46
- if self.dense_model is None:
47
- self.dense_model = HuggingFaceInferenceAPIEmbeddings(
48
- api_key=hf_token,
49
- model_name="sentence-transformers/all-MiniLM-L6-v2"
50
- )
51
- if self.sparse_model is None:
52
- self.sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")
53
 
54
- dense_query_vector = self.dense_model.embed_query(query)
 
 
 
 
55
 
56
  sparse_query = list(self.sparse_model.embed([query]))[0]
 
57
  sparse_query_vector = models.SparseVector(indices=sparse_query.indices,
 
58
  values=sparse_query.values)
59
-
 
 
60
  user_filter = models.Filter(must=[models.FieldCondition(key="user_id" , match=models.MatchValue(value=user_id))])
61
 
 
 
62
  results = self.client.query_points(collection_name=self.collection_name,
 
63
  prefetch=[models.Prefetch(
 
64
  query=dense_query_vector,
 
65
  limit=20,
 
66
  using='dense',
 
67
  filter=user_filter
 
68
  ),
 
69
  models.Prefetch(
 
70
  query=sparse_query_vector,
 
71
  using='sparse',
 
72
  limit=20,
 
73
  filter=user_filter
 
74
  )],
 
75
  query=models.FusionQuery(fusion=models.Fusion.RRF),
 
76
  limit=20)
77
-
 
 
78
  texts = [point.payload.get('text' , '') for point in results.points]
79
 
80
- rerank_scores = self.cloud_rerank(query, texts)
 
 
 
 
81
 
82
  reranked_results = []
 
83
  for point, score in zip(results.points, rerank_scores):
 
84
  reranked_results.append({
 
85
  "text": point.payload.get("text"),
 
86
  "source": point.payload.get("source"),
 
87
  "pages": point.payload.get("pages"),
 
88
  "section": point.payload.get("section"),
 
89
  "original_qdrant_score": point.score,
 
90
  "rerank_score": float(score)
 
91
  })
92
 
 
 
93
  reranked_results.sort(key=lambda x: x["rerank_score"], reverse=True)
94
 
95
- return reranked_results[:5]
 
 
 
 
 
 
 
 
1
  import os
2
+
3
  from dotenv import load_dotenv
4
+
5
  from qdrant_client import QdrantClient
6
+
7
  from qdrant_client import models
8
+
9
+ from fastembed import TextEmbedding, SparseTextEmbedding
10
+
11
+ from fastembed.rerank.cross_encoder import TextCrossEncoder
12
+
13
+
14
 
15
  load_dotenv()
16
 
17
+
18
+
19
  qdrant_api_key = os.getenv("QDRANT_API_KEY")
20
+
21
  qdrant_url = os.getenv("QDRANT_URL")
22
+
23
+
24
 
25
  class Retriever() :
26
+
27
+ def __init__(self , collection_name = 'pdf_rag') :
28
+
29
  self.collection_name = collection_name
30
+
31
  self.client = QdrantClient(url=qdrant_url , api_key=qdrant_api_key)
32
 
33
+
34
+
35
+ self.dense_model = TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
36
+
37
+ self.sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")
38
+
39
+
40
+
41
+ self.reranker = TextCrossEncoder(model_name="Xenova/ms-marco-MiniLM-L-6-v2")
42
+
43
+
44
+
 
 
 
 
 
 
 
 
 
 
45
 
46
 
47
  def retrieve(self , query : str , user_id : str) :
 
 
 
 
 
 
 
 
48
 
49
+
50
+
51
+ dense_query_vector = list(self.dense_model.embed([query]))[0]
52
+
53
+
54
 
55
  sparse_query = list(self.sparse_model.embed([query]))[0]
56
+
57
  sparse_query_vector = models.SparseVector(indices=sparse_query.indices,
58
+
59
  values=sparse_query.values)
60
+
61
+
62
+
63
  user_filter = models.Filter(must=[models.FieldCondition(key="user_id" , match=models.MatchValue(value=user_id))])
64
 
65
+
66
+
67
  results = self.client.query_points(collection_name=self.collection_name,
68
+
69
  prefetch=[models.Prefetch(
70
+
71
  query=dense_query_vector,
72
+
73
  limit=20,
74
+
75
  using='dense',
76
+
77
  filter=user_filter
78
+
79
  ),
80
+
81
  models.Prefetch(
82
+
83
  query=sparse_query_vector,
84
+
85
  using='sparse',
86
+
87
  limit=20,
88
+
89
  filter=user_filter
90
+
91
  )],
92
+
93
  query=models.FusionQuery(fusion=models.Fusion.RRF),
94
+
95
  limit=20)
96
+
97
+
98
+
99
  texts = [point.payload.get('text' , '') for point in results.points]
100
 
101
+
102
+
103
+ rerank_scores = list(self.reranker.rerank(query, texts))
104
+
105
+
106
 
107
  reranked_results = []
108
+
109
  for point, score in zip(results.points, rerank_scores):
110
+
111
  reranked_results.append({
112
+
113
  "text": point.payload.get("text"),
114
+
115
  "source": point.payload.get("source"),
116
+
117
  "pages": point.payload.get("pages"),
118
+
119
  "section": point.payload.get("section"),
120
+
121
  "original_qdrant_score": point.score,
122
+
123
  "rerank_score": float(score)
124
+
125
  })
126
 
127
+
128
+
129
  reranked_results.sort(key=lambda x: x["rerank_score"], reverse=True)
130
 
131
+
132
+
133
+ final_top_results = reranked_results[:5]
134
+
135
+
136
+
137
+ return final_top_results
138
+