Spaces:
Sleeping
Sleeping
Commit ·
0f95a58
1
Parent(s): 46ccd5d
fix: qdrant deploy
Browse files
- demo/main.py +63 -23
- src/scientific_rag/infrastructure/qdrant.py +27 -10
demo/main.py
CHANGED
|
@@ -1,13 +1,11 @@
|
|
| 1 |
-
import json
|
| 2 |
import os
|
| 3 |
-
from pathlib import Path
|
| 4 |
import sys
|
|
|
|
| 5 |
from typing import Any
|
| 6 |
|
| 7 |
import gradio as gr
|
| 8 |
from loguru import logger
|
| 9 |
|
| 10 |
-
|
| 11 |
# Auto-configure for HF Spaces
|
| 12 |
if os.getenv("SPACE_ID"): # Detect HF Spaces environment
|
| 13 |
os.environ.setdefault("QDRANT_URL", ":memory:")
|
|
@@ -16,11 +14,6 @@ if os.getenv("SPACE_ID"): # Detect HF Spaces environment
|
|
| 16 |
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
| 17 |
|
| 18 |
from scientific_rag.application.rag.pipeline import RAGPipeline
|
| 19 |
-
from scientific_rag.domain.documents import PaperChunk
|
| 20 |
-
from scientific_rag.domain.queries import Query, QueryFilters
|
| 21 |
-
from scientific_rag.domain.types import DataSource, SectionType
|
| 22 |
-
from scientific_rag.settings import settings
|
| 23 |
-
|
| 24 |
|
| 25 |
MAIN_HEADER = """
|
| 26 |
<div style="text-align: center; margin-bottom: 40px;">
|
|
@@ -89,7 +82,9 @@ class RAGPipelineWrapper:
|
|
| 89 |
raise ValueError("Please enter your API key.")
|
| 90 |
|
| 91 |
if not use_bm25 and not use_dense:
|
| 92 |
-
raise ValueError(
|
|
|
|
|
|
|
| 93 |
|
| 94 |
if top_k < 1 or top_k > 50:
|
| 95 |
raise ValueError("Top-K must be between 1 and 20.")
|
|
@@ -103,7 +98,9 @@ class RAGPipelineWrapper:
|
|
| 103 |
)
|
| 104 |
|
| 105 |
if provider not in LLM_PROVIDERS:
|
| 106 |
-
raise ValueError(
|
|
|
|
|
|
|
| 107 |
|
| 108 |
if model not in LLM_PROVIDERS[provider]["models"]:
|
| 109 |
raise ValueError(
|
|
@@ -146,7 +143,9 @@ class RAGPipelineWrapper:
|
|
| 146 |
|
| 147 |
return answer, chunks_info
|
| 148 |
|
| 149 |
-
def _format_answer(
|
|
|
|
|
|
|
| 150 |
"""Format RAG response as markdown."""
|
| 151 |
lines = []
|
| 152 |
lines.append(response.answer)
|
|
@@ -163,15 +162,25 @@ class RAGPipelineWrapper:
|
|
| 163 |
metadata_badges.append(
|
| 164 |
f'<span class="metadata-badge">📊 Display Chunks: {min(display_chunks, len(response.retrieved_chunks))}</span>'
|
| 165 |
)
|
| 166 |
-
metadata_badges.append(
|
| 167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
if response.used_filters:
|
| 170 |
-
filters_str = ", ".join(
|
|
|
|
|
|
|
| 171 |
if filters_str:
|
| 172 |
-
metadata_badges.append(
|
|
|
|
|
|
|
| 173 |
|
| 174 |
-
lines.append(
|
|
|
|
|
|
|
| 175 |
|
| 176 |
return "\n".join(lines)
|
| 177 |
|
|
@@ -208,8 +217,16 @@ def process_query(
|
|
| 208 |
) -> tuple[str, str, gr.update, gr.update, gr.update]:
|
| 209 |
try:
|
| 210 |
if not rag_pipeline:
|
| 211 |
-
error_msg =
|
| 212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
answer, chunks = rag_pipeline.process_query(
|
| 215 |
query=query,
|
|
@@ -439,9 +456,13 @@ Cross-encoder model to improve result relevance
|
|
| 439 |
)
|
| 440 |
|
| 441 |
with gr.Row():
|
| 442 |
-
loading_status = gr.Markdown(
|
|
|
|
|
|
|
| 443 |
|
| 444 |
-
with gr.Group(
|
|
|
|
|
|
|
| 445 |
gr.Markdown("## 📝 Example Questions")
|
| 446 |
|
| 447 |
gr.HTML("""
|
|
@@ -488,7 +509,13 @@ Cross-encoder model to improve result relevance
|
|
| 488 |
),
|
| 489 |
),
|
| 490 |
inputs=[],
|
| 491 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 492 |
).then(
|
| 493 |
fn=process_query,
|
| 494 |
inputs=[
|
|
@@ -505,7 +532,13 @@ Cross-encoder model to improve result relevance
|
|
| 505 |
expansion_count,
|
| 506 |
display_chunks,
|
| 507 |
],
|
| 508 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 509 |
)
|
| 510 |
|
| 511 |
clear_btn.click(
|
|
@@ -518,7 +551,14 @@ Cross-encoder model to improve result relevance
|
|
| 518 |
gr.update(value="", visible=False),
|
| 519 |
),
|
| 520 |
inputs=[],
|
| 521 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
)
|
| 523 |
|
| 524 |
return demo
|
|
|
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
import sys
|
| 3 |
+
from pathlib import Path
|
| 4 |
from typing import Any
|
| 5 |
|
| 6 |
import gradio as gr
|
| 7 |
from loguru import logger
|
| 8 |
|
|
|
|
| 9 |
# Auto-configure for HF Spaces
|
| 10 |
if os.getenv("SPACE_ID"): # Detect HF Spaces environment
|
| 11 |
os.environ.setdefault("QDRANT_URL", ":memory:")
|
|
|
|
| 14 |
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
| 15 |
|
| 16 |
from scientific_rag.application.rag.pipeline import RAGPipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
MAIN_HEADER = """
|
| 19 |
<div style="text-align: center; margin-bottom: 40px;">
|
|
|
|
| 82 |
raise ValueError("Please enter your API key.")
|
| 83 |
|
| 84 |
if not use_bm25 and not use_dense:
|
| 85 |
+
raise ValueError(
|
| 86 |
+
"Please enable at least one retrieval method (BM25 or Dense)."
|
| 87 |
+
)
|
| 88 |
|
| 89 |
if top_k < 1 or top_k > 50:
|
| 90 |
raise ValueError("Top-K must be between 1 and 20.")
|
|
|
|
| 98 |
)
|
| 99 |
|
| 100 |
if provider not in LLM_PROVIDERS:
|
| 101 |
+
raise ValueError(
|
| 102 |
+
f"Invalid provider: {provider}. Must be one of {list(LLM_PROVIDERS.keys())}"
|
| 103 |
+
)
|
| 104 |
|
| 105 |
if model not in LLM_PROVIDERS[provider]["models"]:
|
| 106 |
raise ValueError(
|
|
|
|
| 143 |
|
| 144 |
return answer, chunks_info
|
| 145 |
|
| 146 |
+
def _format_answer(
|
| 147 |
+
self, response, provider: str, model: str, display_chunks: int
|
| 148 |
+
) -> str:
|
| 149 |
"""Format RAG response as markdown."""
|
| 150 |
lines = []
|
| 151 |
lines.append(response.answer)
|
|
|
|
| 162 |
metadata_badges.append(
|
| 163 |
f'<span class="metadata-badge">📊 Display Chunks: {min(display_chunks, len(response.retrieved_chunks))}</span>'
|
| 164 |
)
|
| 165 |
+
metadata_badges.append(
|
| 166 |
+
f'<span class="metadata-badge">⏱️ Execution Time: {response.execution_time:.2f}s</span>'
|
| 167 |
+
)
|
| 168 |
+
metadata_badges.append(
|
| 169 |
+
f'<span class="metadata-badge">🤖 Model: {provider} / {model}</span>'
|
| 170 |
+
)
|
| 171 |
|
| 172 |
if response.used_filters:
|
| 173 |
+
filters_str = ", ".join(
|
| 174 |
+
[f"{k}={v}" for k, v in response.used_filters.items() if v != "any"]
|
| 175 |
+
)
|
| 176 |
if filters_str:
|
| 177 |
+
metadata_badges.append(
|
| 178 |
+
f'<span class="metadata-badge">🔎 Filters: {filters_str}</span>'
|
| 179 |
+
)
|
| 180 |
|
| 181 |
+
lines.append(
|
| 182 |
+
'<div class="metadata-container">' + " ".join(metadata_badges) + "</div>"
|
| 183 |
+
)
|
| 184 |
|
| 185 |
return "\n".join(lines)
|
| 186 |
|
|
|
|
| 217 |
) -> tuple[str, str, gr.update, gr.update, gr.update]:
|
| 218 |
try:
|
| 219 |
if not rag_pipeline:
|
| 220 |
+
error_msg = (
|
| 221 |
+
"⚠️ **System Error**: RAG Pipeline not initialized. Please check logs."
|
| 222 |
+
)
|
| 223 |
+
return (
|
| 224 |
+
error_msg,
|
| 225 |
+
"",
|
| 226 |
+
gr.update(visible=False),
|
| 227 |
+
gr.update(visible=True),
|
| 228 |
+
gr.update(value="", visible=False),
|
| 229 |
+
)
|
| 230 |
|
| 231 |
answer, chunks = rag_pipeline.process_query(
|
| 232 |
query=query,
|
|
|
|
| 456 |
)
|
| 457 |
|
| 458 |
with gr.Row():
|
| 459 |
+
loading_status = gr.Markdown(
|
| 460 |
+
value="", visible=False, elem_classes="loading-indicator"
|
| 461 |
+
)
|
| 462 |
|
| 463 |
+
with gr.Group(
|
| 464 |
+
visible=True, elem_classes="examples-section"
|
| 465 |
+
) as examples_section:
|
| 466 |
gr.Markdown("## 📝 Example Questions")
|
| 467 |
|
| 468 |
gr.HTML("""
|
|
|
|
| 509 |
),
|
| 510 |
),
|
| 511 |
inputs=[],
|
| 512 |
+
outputs=[
|
| 513 |
+
answer_output,
|
| 514 |
+
chunks_output,
|
| 515 |
+
examples_section,
|
| 516 |
+
answer_section,
|
| 517 |
+
loading_status,
|
| 518 |
+
],
|
| 519 |
).then(
|
| 520 |
fn=process_query,
|
| 521 |
inputs=[
|
|
|
|
| 532 |
expansion_count,
|
| 533 |
display_chunks,
|
| 534 |
],
|
| 535 |
+
outputs=[
|
| 536 |
+
answer_output,
|
| 537 |
+
chunks_output,
|
| 538 |
+
examples_section,
|
| 539 |
+
answer_section,
|
| 540 |
+
loading_status,
|
| 541 |
+
],
|
| 542 |
)
|
| 543 |
|
| 544 |
clear_btn.click(
|
|
|
|
| 551 |
gr.update(value="", visible=False),
|
| 552 |
),
|
| 553 |
inputs=[],
|
| 554 |
+
outputs=[
|
| 555 |
+
query,
|
| 556 |
+
answer_output,
|
| 557 |
+
chunks_output,
|
| 558 |
+
examples_section,
|
| 559 |
+
answer_section,
|
| 560 |
+
loading_status,
|
| 561 |
+
],
|
| 562 |
)
|
| 563 |
|
| 564 |
return demo
|
src/scientific_rag/infrastructure/qdrant.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
| 1 |
-
from collections.abc import Sequence
|
| 2 |
from typing import Any
|
| 3 |
|
| 4 |
-
from fastembed import SparseTextEmbedding
|
| 5 |
from loguru import logger
|
| 6 |
from qdrant_client import QdrantClient as SyncQdrantClient
|
| 7 |
from qdrant_client.models import (
|
|
@@ -10,7 +8,6 @@ from qdrant_client.models import (
|
|
| 10 |
Filter,
|
| 11 |
MatchValue,
|
| 12 |
Modifier,
|
| 13 |
-
NamedSparseVector,
|
| 14 |
PointStruct,
|
| 15 |
SparseIndexParams,
|
| 16 |
SparseVector,
|
|
@@ -30,14 +27,24 @@ class QdrantService:
|
|
| 30 |
self.collection_name = settings.qdrant_collection_name
|
| 31 |
|
| 32 |
logger.info(f"Initializing Qdrant client: {self.url}")
|
| 33 |
-
self.client = SyncQdrantClient(url=self.url, api_key=self.api_key, timeout=30)
|
| 34 |
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
if self.client.collection_exists(self.collection_name):
|
| 37 |
logger.info(f"Collection '{self.collection_name}' already exists")
|
| 38 |
return
|
| 39 |
|
| 40 |
-
logger.info(
|
|
|
|
|
|
|
| 41 |
self.client.create_collection(
|
| 42 |
collection_name=self.collection_name,
|
| 43 |
vectors_config={
|
|
@@ -61,7 +68,9 @@ class QdrantService:
|
|
| 61 |
)
|
| 62 |
logger.info(f"Collection '{self.collection_name}' created with indexes")
|
| 63 |
|
| 64 |
-
def upsert_chunks(
|
|
|
|
|
|
|
| 65 |
if not chunks:
|
| 66 |
return 0
|
| 67 |
|
|
@@ -74,7 +83,9 @@ class QdrantService:
|
|
| 74 |
|
| 75 |
if sparse_embeddings and i < len(sparse_embeddings):
|
| 76 |
sparse = sparse_embeddings[i]
|
| 77 |
-
vectors["bm25"] = SparseVector(
|
|
|
|
|
|
|
| 78 |
|
| 79 |
points.append(
|
| 80 |
PointStruct(
|
|
@@ -163,11 +174,17 @@ class QdrantService:
|
|
| 163 |
return None
|
| 164 |
|
| 165 |
must_conditions = []
|
| 166 |
-
target_list =
|
|
|
|
|
|
|
| 167 |
|
| 168 |
for item in target_list:
|
| 169 |
if "key" in item and "match" in item:
|
| 170 |
-
must_conditions.append(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
return Filter(must=must_conditions) if must_conditions else None
|
| 173 |
|
|
|
|
|
|
|
| 1 |
from typing import Any
|
| 2 |
|
|
|
|
| 3 |
from loguru import logger
|
| 4 |
from qdrant_client import QdrantClient as SyncQdrantClient
|
| 5 |
from qdrant_client.models import (
|
|
|
|
| 8 |
Filter,
|
| 9 |
MatchValue,
|
| 10 |
Modifier,
|
|
|
|
| 11 |
PointStruct,
|
| 12 |
SparseIndexParams,
|
| 13 |
SparseVector,
|
|
|
|
| 27 |
self.collection_name = settings.qdrant_collection_name
|
| 28 |
|
| 29 |
logger.info(f"Initializing Qdrant client: {self.url}")
|
|
|
|
| 30 |
|
| 31 |
+
if self.url == ":memory:":
|
| 32 |
+
self.client = SyncQdrantClient(location=":memory:", timeout=30)
|
| 33 |
+
else:
|
| 34 |
+
self.client = SyncQdrantClient(
|
| 35 |
+
url=self.url, api_key=self.api_key, timeout=30
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
def create_collection(
|
| 39 |
+
self, vector_size: int = 384, distance: Distance = Distance.COSINE
|
| 40 |
+
) -> None:
|
| 41 |
if self.client.collection_exists(self.collection_name):
|
| 42 |
logger.info(f"Collection '{self.collection_name}' already exists")
|
| 43 |
return
|
| 44 |
|
| 45 |
+
logger.info(
|
| 46 |
+
f"Creating collection '{self.collection_name}' with dense and sparse vectors"
|
| 47 |
+
)
|
| 48 |
self.client.create_collection(
|
| 49 |
collection_name=self.collection_name,
|
| 50 |
vectors_config={
|
|
|
|
| 68 |
)
|
| 69 |
logger.info(f"Collection '{self.collection_name}' created with indexes")
|
| 70 |
|
| 71 |
+
def upsert_chunks(
|
| 72 |
+
self, chunks: list[PaperChunk], sparse_embeddings: list[Any] | None = None
|
| 73 |
+
) -> int:
|
| 74 |
if not chunks:
|
| 75 |
return 0
|
| 76 |
|
|
|
|
| 83 |
|
| 84 |
if sparse_embeddings and i < len(sparse_embeddings):
|
| 85 |
sparse = sparse_embeddings[i]
|
| 86 |
+
vectors["bm25"] = SparseVector(
|
| 87 |
+
indices=sparse.indices.tolist(), values=sparse.values.tolist()
|
| 88 |
+
)
|
| 89 |
|
| 90 |
points.append(
|
| 91 |
PointStruct(
|
|
|
|
| 174 |
return None
|
| 175 |
|
| 176 |
must_conditions = []
|
| 177 |
+
target_list = (
|
| 178 |
+
filter_dict.get("must", []) if "must" in filter_dict else [filter_dict]
|
| 179 |
+
)
|
| 180 |
|
| 181 |
for item in target_list:
|
| 182 |
if "key" in item and "match" in item:
|
| 183 |
+
must_conditions.append(
|
| 184 |
+
FieldCondition(
|
| 185 |
+
key=item["key"], match=MatchValue(value=item["match"]["value"])
|
| 186 |
+
)
|
| 187 |
+
)
|
| 188 |
|
| 189 |
return Filter(must=must_conditions) if must_conditions else None
|
| 190 |
|