Spaces:
Sleeping
Sleeping
Developer committed on
Commit ·
e0273cc
1
Parent(s): 7406588
Fix: Load embedding model when loading existing Qdrant collections
Browse files
- Qdrant get_collection now retrieves embedding model from document metadata
- Store embedding_model and chunking_strategy in Qdrant document payloads
- Extract embedding model to session state for both ChromaDB and Qdrant
- Switch theme to light mode
- .streamlit/config.toml +1 -1
- streamlit_app.py +21 -0
- vector_store.py +45 -1
.streamlit/config.toml
CHANGED
|
@@ -20,5 +20,5 @@ serverPort = 7860
|
|
| 20 |
level = "warning"
|
| 21 |
|
| 22 |
[theme]
|
| 23 |
-
base = "dark"
|
| 24 |
primaryColor = "#7C3AED"
|
|
|
|
| 20 |
level = "warning"
|
| 21 |
|
| 22 |
[theme]
|
| 23 |
+
base = "light"
|
| 24 |
primaryColor = "#7C3AED"
|
streamlit_app.py
CHANGED
|
@@ -557,6 +557,27 @@ def load_existing_collection(api_key: str, collection_name: str, llm_provider: s
|
|
| 557 |
st.session_state.collection_name = collection_name
|
| 558 |
st.session_state.llm_provider = llm_provider
|
| 559 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 560 |
# Display system prompt and model info
|
| 561 |
provider_icon = "☁️" if llm_provider == "groq" else "🖥️"
|
| 562 |
st.success(f"✅ Collection '{collection_name}' loaded successfully! {provider_icon} Using {llm_provider.upper()}")
|
|
|
|
| 557 |
st.session_state.collection_name = collection_name
|
| 558 |
st.session_state.llm_provider = llm_provider
|
| 559 |
|
| 560 |
+
# Extract embedding model from collection metadata or vector store
|
| 561 |
+
embedding_model_name = None
|
| 562 |
+
|
| 563 |
+
# For ChromaDB: check collection metadata
|
| 564 |
+
if hasattr(vector_store, 'current_collection') and vector_store.current_collection:
|
| 565 |
+
if hasattr(vector_store.current_collection, 'metadata'):
|
| 566 |
+
collection_metadata = vector_store.current_collection.metadata
|
| 567 |
+
if collection_metadata and "embedding_model" in collection_metadata:
|
| 568 |
+
embedding_model_name = collection_metadata["embedding_model"]
|
| 569 |
+
|
| 570 |
+
# For Qdrant or fallback: check if embedding_model was loaded on the vector store
|
| 571 |
+
if not embedding_model_name and hasattr(vector_store, 'embedding_model') and vector_store.embedding_model:
|
| 572 |
+
if hasattr(vector_store.embedding_model, 'model_name'):
|
| 573 |
+
embedding_model_name = vector_store.embedding_model.model_name
|
| 574 |
+
|
| 575 |
+
# Set session state
|
| 576 |
+
if embedding_model_name:
|
| 577 |
+
st.session_state.embedding_model = embedding_model_name
|
| 578 |
+
else:
|
| 579 |
+
st.session_state.embedding_model = None
|
| 580 |
+
|
| 581 |
# Display system prompt and model info
|
| 582 |
provider_icon = "☁️" if llm_provider == "groq" else "🖥️"
|
| 583 |
st.success(f"✅ Collection '{collection_name}' loaded successfully! {provider_icon} Using {llm_provider.upper()}")
|
vector_store.py
CHANGED
|
@@ -532,6 +532,48 @@ class QdrantManager:
|
|
| 532 |
info = self.client.get_collection(collection_name)
|
| 533 |
self.vector_size = info.config.params.vectors.size
|
| 534 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
print(f"[QDRANT] Loaded collection: {collection_name}")
|
| 536 |
return self.current_collection
|
| 537 |
|
|
@@ -666,7 +708,9 @@ class QdrantManager:
|
|
| 666 |
"question": sample.get("question", ""),
|
| 667 |
"answer": sample.get("answer", ""),
|
| 668 |
"dataset": sample.get("dataset", ""),
|
| 669 |
-
"total_docs": len(documents)
|
|
|
|
|
|
|
| 670 |
})
|
| 671 |
|
| 672 |
# Add all chunks to collection
|
|
|
|
| 532 |
info = self.client.get_collection(collection_name)
|
| 533 |
self.vector_size = info.config.params.vectors.size
|
| 534 |
|
| 535 |
+
# Try to load embedding model from first document's metadata
|
| 536 |
+
embedding_model_name = None
|
| 537 |
+
try:
|
| 538 |
+
# Scroll to get first point
|
| 539 |
+
points, _ = self.client.scroll(
|
| 540 |
+
collection_name=collection_name,
|
| 541 |
+
limit=1,
|
| 542 |
+
with_payload=True
|
| 543 |
+
)
|
| 544 |
+
if points and len(points) > 0:
|
| 545 |
+
payload = points[0].payload
|
| 546 |
+
embedding_model_name = payload.get("embedding_model")
|
| 547 |
+
if "chunking_strategy" in payload:
|
| 548 |
+
self.chunking_strategy = payload["chunking_strategy"]
|
| 549 |
+
except Exception as e:
|
| 550 |
+
print(f"[QDRANT] Warning: Could not retrieve metadata: {e}")
|
| 551 |
+
|
| 552 |
+
# If not found in metadata, try to infer from collection name
|
| 553 |
+
if not embedding_model_name:
|
| 554 |
+
# Collection name format: dataset_strategy_modelname
|
| 555 |
+
# Try common embedding models
|
| 556 |
+
known_models = [
|
| 557 |
+
"all-mpnet-base-v2",
|
| 558 |
+
"all-MiniLM-L6-v2",
|
| 559 |
+
"paraphrase-MiniLM-L6-v2",
|
| 560 |
+
"multi-qa-MiniLM-L6-cos-v1"
|
| 561 |
+
]
|
| 562 |
+
for model in known_models:
|
| 563 |
+
if model.lower().replace("-", "") in collection_name.lower().replace("-", "").replace("_", ""):
|
| 564 |
+
embedding_model_name = f"sentence-transformers/{model}"
|
| 565 |
+
break
|
| 566 |
+
# Default fallback
|
| 567 |
+
if not embedding_model_name:
|
| 568 |
+
embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
|
| 569 |
+
print(f"[QDRANT] Warning: Could not determine embedding model, using default: {embedding_model_name}")
|
| 570 |
+
|
| 571 |
+
# Load the embedding model
|
| 572 |
+
if embedding_model_name:
|
| 573 |
+
self.embedding_model = EmbeddingFactory.create_embedding_model(embedding_model_name)
|
| 574 |
+
self.embedding_model.load_model()
|
| 575 |
+
print(f"[QDRANT] Loaded embedding model: {embedding_model_name}")
|
| 576 |
+
|
| 577 |
print(f"[QDRANT] Loaded collection: {collection_name}")
|
| 578 |
return self.current_collection
|
| 579 |
|
|
|
|
| 708 |
"question": sample.get("question", ""),
|
| 709 |
"answer": sample.get("answer", ""),
|
| 710 |
"dataset": sample.get("dataset", ""),
|
| 711 |
+
"total_docs": len(documents),
|
| 712 |
+
"embedding_model": embedding_model_name,
|
| 713 |
+
"chunking_strategy": chunking_strategy
|
| 714 |
})
|
| 715 |
|
| 716 |
# Add all chunks to collection
|