DenysKovalML commited on
Commit
0f95a58
·
1 Parent(s): 46ccd5d

fix: qdrant deploy

Browse files
demo/main.py CHANGED
@@ -1,13 +1,11 @@
1
- import json
2
  import os
3
- from pathlib import Path
4
  import sys
 
5
  from typing import Any
6
 
7
  import gradio as gr
8
  from loguru import logger
9
 
10
-
11
  # Auto-configure for HF Spaces
12
  if os.getenv("SPACE_ID"): # Detect HF Spaces environment
13
  os.environ.setdefault("QDRANT_URL", ":memory:")
@@ -16,11 +14,6 @@ if os.getenv("SPACE_ID"): # Detect HF Spaces environment
16
  sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
17
 
18
  from scientific_rag.application.rag.pipeline import RAGPipeline
19
- from scientific_rag.domain.documents import PaperChunk
20
- from scientific_rag.domain.queries import Query, QueryFilters
21
- from scientific_rag.domain.types import DataSource, SectionType
22
- from scientific_rag.settings import settings
23
-
24
 
25
  MAIN_HEADER = """
26
  <div style="text-align: center; margin-bottom: 40px;">
@@ -89,7 +82,9 @@ class RAGPipelineWrapper:
89
  raise ValueError("Please enter your API key.")
90
 
91
  if not use_bm25 and not use_dense:
92
- raise ValueError("Please enable at least one retrieval method (BM25 or Dense).")
 
 
93
 
94
  if top_k < 1 or top_k > 50:
95
  raise ValueError("Top-K must be between 1 and 20.")
@@ -103,7 +98,9 @@ class RAGPipelineWrapper:
103
  )
104
 
105
  if provider not in LLM_PROVIDERS:
106
- raise ValueError(f"Invalid provider: {provider}. Must be one of {list(LLM_PROVIDERS.keys())}")
 
 
107
 
108
  if model not in LLM_PROVIDERS[provider]["models"]:
109
  raise ValueError(
@@ -146,7 +143,9 @@ class RAGPipelineWrapper:
146
 
147
  return answer, chunks_info
148
 
149
- def _format_answer(self, response, provider: str, model: str, display_chunks: int) -> str:
 
 
150
  """Format RAG response as markdown."""
151
  lines = []
152
  lines.append(response.answer)
@@ -163,15 +162,25 @@ class RAGPipelineWrapper:
163
  metadata_badges.append(
164
  f'<span class="metadata-badge">📊 Display Chunks: {min(display_chunks, len(response.retrieved_chunks))}</span>'
165
  )
166
- metadata_badges.append(f'<span class="metadata-badge">⏱️ Execution Time: {response.execution_time:.2f}s</span>')
167
- metadata_badges.append(f'<span class="metadata-badge">🤖 Model: {provider} / {model}</span>')
 
 
 
 
168
 
169
  if response.used_filters:
170
- filters_str = ", ".join([f"{k}={v}" for k, v in response.used_filters.items() if v != "any"])
 
 
171
  if filters_str:
172
- metadata_badges.append(f'<span class="metadata-badge">🔎 Filters: {filters_str}</span>')
 
 
173
 
174
- lines.append('<div class="metadata-container">' + " ".join(metadata_badges) + "</div>")
 
 
175
 
176
  return "\n".join(lines)
177
 
@@ -208,8 +217,16 @@ def process_query(
208
  ) -> tuple[str, str, gr.update, gr.update, gr.update]:
209
  try:
210
  if not rag_pipeline:
211
- error_msg = "⚠️ **System Error**: RAG Pipeline not initialized. Please check logs."
212
- return error_msg, "", gr.update(visible=False), gr.update(visible=True), gr.update(value="", visible=False)
 
 
 
 
 
 
 
 
213
 
214
  answer, chunks = rag_pipeline.process_query(
215
  query=query,
@@ -439,9 +456,13 @@ Cross-encoder model to improve result relevance
439
  )
440
 
441
  with gr.Row():
442
- loading_status = gr.Markdown(value="", visible=False, elem_classes="loading-indicator")
 
 
443
 
444
- with gr.Group(visible=True, elem_classes="examples-section") as examples_section:
 
 
445
  gr.Markdown("## 📝 Example Questions")
446
 
447
  gr.HTML("""
@@ -488,7 +509,13 @@ Cross-encoder model to improve result relevance
488
  ),
489
  ),
490
  inputs=[],
491
- outputs=[answer_output, chunks_output, examples_section, answer_section, loading_status],
 
 
 
 
 
 
492
  ).then(
493
  fn=process_query,
494
  inputs=[
@@ -505,7 +532,13 @@ Cross-encoder model to improve result relevance
505
  expansion_count,
506
  display_chunks,
507
  ],
508
- outputs=[answer_output, chunks_output, examples_section, answer_section, loading_status],
 
 
 
 
 
 
509
  )
510
 
511
  clear_btn.click(
@@ -518,7 +551,14 @@ Cross-encoder model to improve result relevance
518
  gr.update(value="", visible=False),
519
  ),
520
  inputs=[],
521
- outputs=[query, answer_output, chunks_output, examples_section, answer_section, loading_status],
 
 
 
 
 
 
 
522
  )
523
 
524
  return demo
 
 
1
  import os
 
2
  import sys
3
+ from pathlib import Path
4
  from typing import Any
5
 
6
  import gradio as gr
7
  from loguru import logger
8
 
 
9
  # Auto-configure for HF Spaces
10
  if os.getenv("SPACE_ID"): # Detect HF Spaces environment
11
  os.environ.setdefault("QDRANT_URL", ":memory:")
 
14
  sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
15
 
16
  from scientific_rag.application.rag.pipeline import RAGPipeline
 
 
 
 
 
17
 
18
  MAIN_HEADER = """
19
  <div style="text-align: center; margin-bottom: 40px;">
 
82
  raise ValueError("Please enter your API key.")
83
 
84
  if not use_bm25 and not use_dense:
85
+ raise ValueError(
86
+ "Please enable at least one retrieval method (BM25 or Dense)."
87
+ )
88
 
89
  if top_k < 1 or top_k > 50:
90
  raise ValueError("Top-K must be between 1 and 20.")
 
98
  )
99
 
100
  if provider not in LLM_PROVIDERS:
101
+ raise ValueError(
102
+ f"Invalid provider: {provider}. Must be one of {list(LLM_PROVIDERS.keys())}"
103
+ )
104
 
105
  if model not in LLM_PROVIDERS[provider]["models"]:
106
  raise ValueError(
 
143
 
144
  return answer, chunks_info
145
 
146
+ def _format_answer(
147
+ self, response, provider: str, model: str, display_chunks: int
148
+ ) -> str:
149
  """Format RAG response as markdown."""
150
  lines = []
151
  lines.append(response.answer)
 
162
  metadata_badges.append(
163
  f'<span class="metadata-badge">📊 Display Chunks: {min(display_chunks, len(response.retrieved_chunks))}</span>'
164
  )
165
+ metadata_badges.append(
166
+ f'<span class="metadata-badge">⏱️ Execution Time: {response.execution_time:.2f}s</span>'
167
+ )
168
+ metadata_badges.append(
169
+ f'<span class="metadata-badge">🤖 Model: {provider} / {model}</span>'
170
+ )
171
 
172
  if response.used_filters:
173
+ filters_str = ", ".join(
174
+ [f"{k}={v}" for k, v in response.used_filters.items() if v != "any"]
175
+ )
176
  if filters_str:
177
+ metadata_badges.append(
178
+ f'<span class="metadata-badge">🔎 Filters: {filters_str}</span>'
179
+ )
180
 
181
+ lines.append(
182
+ '<div class="metadata-container">' + " ".join(metadata_badges) + "</div>"
183
+ )
184
 
185
  return "\n".join(lines)
186
 
 
217
  ) -> tuple[str, str, gr.update, gr.update, gr.update]:
218
  try:
219
  if not rag_pipeline:
220
+ error_msg = (
221
+ "⚠️ **System Error**: RAG Pipeline not initialized. Please check logs."
222
+ )
223
+ return (
224
+ error_msg,
225
+ "",
226
+ gr.update(visible=False),
227
+ gr.update(visible=True),
228
+ gr.update(value="", visible=False),
229
+ )
230
 
231
  answer, chunks = rag_pipeline.process_query(
232
  query=query,
 
456
  )
457
 
458
  with gr.Row():
459
+ loading_status = gr.Markdown(
460
+ value="", visible=False, elem_classes="loading-indicator"
461
+ )
462
 
463
+ with gr.Group(
464
+ visible=True, elem_classes="examples-section"
465
+ ) as examples_section:
466
  gr.Markdown("## 📝 Example Questions")
467
 
468
  gr.HTML("""
 
509
  ),
510
  ),
511
  inputs=[],
512
+ outputs=[
513
+ answer_output,
514
+ chunks_output,
515
+ examples_section,
516
+ answer_section,
517
+ loading_status,
518
+ ],
519
  ).then(
520
  fn=process_query,
521
  inputs=[
 
532
  expansion_count,
533
  display_chunks,
534
  ],
535
+ outputs=[
536
+ answer_output,
537
+ chunks_output,
538
+ examples_section,
539
+ answer_section,
540
+ loading_status,
541
+ ],
542
  )
543
 
544
  clear_btn.click(
 
551
  gr.update(value="", visible=False),
552
  ),
553
  inputs=[],
554
+ outputs=[
555
+ query,
556
+ answer_output,
557
+ chunks_output,
558
+ examples_section,
559
+ answer_section,
560
+ loading_status,
561
+ ],
562
  )
563
 
564
  return demo
src/scientific_rag/infrastructure/qdrant.py CHANGED
@@ -1,7 +1,5 @@
1
- from collections.abc import Sequence
2
  from typing import Any
3
 
4
- from fastembed import SparseTextEmbedding
5
  from loguru import logger
6
  from qdrant_client import QdrantClient as SyncQdrantClient
7
  from qdrant_client.models import (
@@ -10,7 +8,6 @@ from qdrant_client.models import (
10
  Filter,
11
  MatchValue,
12
  Modifier,
13
- NamedSparseVector,
14
  PointStruct,
15
  SparseIndexParams,
16
  SparseVector,
@@ -30,14 +27,24 @@ class QdrantService:
30
  self.collection_name = settings.qdrant_collection_name
31
 
32
  logger.info(f"Initializing Qdrant client: {self.url}")
33
- self.client = SyncQdrantClient(url=self.url, api_key=self.api_key, timeout=30)
34
 
35
- def create_collection(self, vector_size: int = 384, distance: Distance = Distance.COSINE) -> None:
 
 
 
 
 
 
 
 
 
36
  if self.client.collection_exists(self.collection_name):
37
  logger.info(f"Collection '{self.collection_name}' already exists")
38
  return
39
 
40
- logger.info(f"Creating collection '{self.collection_name}' with dense and sparse vectors")
 
 
41
  self.client.create_collection(
42
  collection_name=self.collection_name,
43
  vectors_config={
@@ -61,7 +68,9 @@ class QdrantService:
61
  )
62
  logger.info(f"Collection '{self.collection_name}' created with indexes")
63
 
64
- def upsert_chunks(self, chunks: list[PaperChunk], sparse_embeddings: list[Any] | None = None) -> int:
 
 
65
  if not chunks:
66
  return 0
67
 
@@ -74,7 +83,9 @@ class QdrantService:
74
 
75
  if sparse_embeddings and i < len(sparse_embeddings):
76
  sparse = sparse_embeddings[i]
77
- vectors["bm25"] = SparseVector(indices=sparse.indices.tolist(), values=sparse.values.tolist())
 
 
78
 
79
  points.append(
80
  PointStruct(
@@ -163,11 +174,17 @@ class QdrantService:
163
  return None
164
 
165
  must_conditions = []
166
- target_list = filter_dict.get("must", []) if "must" in filter_dict else [filter_dict]
 
 
167
 
168
  for item in target_list:
169
  if "key" in item and "match" in item:
170
- must_conditions.append(FieldCondition(key=item["key"], match=MatchValue(value=item["match"]["value"])))
 
 
 
 
171
 
172
  return Filter(must=must_conditions) if must_conditions else None
173
 
 
 
1
  from typing import Any
2
 
 
3
  from loguru import logger
4
  from qdrant_client import QdrantClient as SyncQdrantClient
5
  from qdrant_client.models import (
 
8
  Filter,
9
  MatchValue,
10
  Modifier,
 
11
  PointStruct,
12
  SparseIndexParams,
13
  SparseVector,
 
27
  self.collection_name = settings.qdrant_collection_name
28
 
29
  logger.info(f"Initializing Qdrant client: {self.url}")
 
30
 
31
+ if self.url == ":memory:":
32
+ self.client = SyncQdrantClient(location=":memory:", timeout=30)
33
+ else:
34
+ self.client = SyncQdrantClient(
35
+ url=self.url, api_key=self.api_key, timeout=30
36
+ )
37
+
38
+ def create_collection(
39
+ self, vector_size: int = 384, distance: Distance = Distance.COSINE
40
+ ) -> None:
41
  if self.client.collection_exists(self.collection_name):
42
  logger.info(f"Collection '{self.collection_name}' already exists")
43
  return
44
 
45
+ logger.info(
46
+ f"Creating collection '{self.collection_name}' with dense and sparse vectors"
47
+ )
48
  self.client.create_collection(
49
  collection_name=self.collection_name,
50
  vectors_config={
 
68
  )
69
  logger.info(f"Collection '{self.collection_name}' created with indexes")
70
 
71
+ def upsert_chunks(
72
+ self, chunks: list[PaperChunk], sparse_embeddings: list[Any] | None = None
73
+ ) -> int:
74
  if not chunks:
75
  return 0
76
 
 
83
 
84
  if sparse_embeddings and i < len(sparse_embeddings):
85
  sparse = sparse_embeddings[i]
86
+ vectors["bm25"] = SparseVector(
87
+ indices=sparse.indices.tolist(), values=sparse.values.tolist()
88
+ )
89
 
90
  points.append(
91
  PointStruct(
 
174
  return None
175
 
176
  must_conditions = []
177
+ target_list = (
178
+ filter_dict.get("must", []) if "must" in filter_dict else [filter_dict]
179
+ )
180
 
181
  for item in target_list:
182
  if "key" in item and "match" in item:
183
+ must_conditions.append(
184
+ FieldCondition(
185
+ key=item["key"], match=MatchValue(value=item["match"]["value"])
186
+ )
187
+ )
188
 
189
  return Filter(must=must_conditions) if must_conditions else None
190