Developer committed on
Commit
e0273cc
·
1 Parent(s): 7406588

Fix: Load embedding model when loading existing Qdrant collections

Browse files

- Qdrant get_collection now retrieves embedding model from document metadata
- Store embedding_model and chunking_strategy in Qdrant document payloads
- Extract embedding model to session state for both ChromaDB and Qdrant
- Switch theme to light mode

Files changed (3) hide show
  1. .streamlit/config.toml +1 -1
  2. streamlit_app.py +21 -0
  3. vector_store.py +45 -1
.streamlit/config.toml CHANGED
@@ -20,5 +20,5 @@ serverPort = 7860
20
  level = "warning"
21
 
22
  [theme]
23
- base = "dark"
24
  primaryColor = "#7C3AED"
 
20
  level = "warning"
21
 
22
  [theme]
23
+ base = "light"
24
  primaryColor = "#7C3AED"
streamlit_app.py CHANGED
@@ -557,6 +557,27 @@ def load_existing_collection(api_key: str, collection_name: str, llm_provider: s
557
  st.session_state.collection_name = collection_name
558
  st.session_state.llm_provider = llm_provider
559
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
560
  # Display system prompt and model info
561
  provider_icon = "☁️" if llm_provider == "groq" else "🖥️"
562
  st.success(f"✅ Collection '{collection_name}' loaded successfully! {provider_icon} Using {llm_provider.upper()}")
 
557
  st.session_state.collection_name = collection_name
558
  st.session_state.llm_provider = llm_provider
559
 
560
+ # Extract embedding model from collection metadata or vector store
561
+ embedding_model_name = None
562
+
563
+ # For ChromaDB: check collection metadata
564
+ if hasattr(vector_store, 'current_collection') and vector_store.current_collection:
565
+ if hasattr(vector_store.current_collection, 'metadata'):
566
+ collection_metadata = vector_store.current_collection.metadata
567
+ if collection_metadata and "embedding_model" in collection_metadata:
568
+ embedding_model_name = collection_metadata["embedding_model"]
569
+
570
+ # For Qdrant or fallback: check if embedding_model was loaded on the vector store
571
+ if not embedding_model_name and hasattr(vector_store, 'embedding_model') and vector_store.embedding_model:
572
+ if hasattr(vector_store.embedding_model, 'model_name'):
573
+ embedding_model_name = vector_store.embedding_model.model_name
574
+
575
+ # Set session state
576
+ if embedding_model_name:
577
+ st.session_state.embedding_model = embedding_model_name
578
+ else:
579
+ st.session_state.embedding_model = None
580
+
581
  # Display system prompt and model info
582
  provider_icon = "☁️" if llm_provider == "groq" else "🖥️"
583
  st.success(f"✅ Collection '{collection_name}' loaded successfully! {provider_icon} Using {llm_provider.upper()}")
vector_store.py CHANGED
@@ -532,6 +532,48 @@ class QdrantManager:
532
  info = self.client.get_collection(collection_name)
533
  self.vector_size = info.config.params.vectors.size
534
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
535
  print(f"[QDRANT] Loaded collection: {collection_name}")
536
  return self.current_collection
537
 
@@ -666,7 +708,9 @@ class QdrantManager:
666
  "question": sample.get("question", ""),
667
  "answer": sample.get("answer", ""),
668
  "dataset": sample.get("dataset", ""),
669
- "total_docs": len(documents)
 
 
670
  })
671
 
672
  # Add all chunks to collection
 
532
  info = self.client.get_collection(collection_name)
533
  self.vector_size = info.config.params.vectors.size
534
 
535
+ # Try to load embedding model from first document's metadata
536
+ embedding_model_name = None
537
+ try:
538
+ # Scroll to get first point
539
+ points, _ = self.client.scroll(
540
+ collection_name=collection_name,
541
+ limit=1,
542
+ with_payload=True
543
+ )
544
+ if points and len(points) > 0:
545
+ payload = points[0].payload
546
+ embedding_model_name = payload.get("embedding_model")
547
+ if "chunking_strategy" in payload:
548
+ self.chunking_strategy = payload["chunking_strategy"]
549
+ except Exception as e:
550
+ print(f"[QDRANT] Warning: Could not retrieve metadata: {e}")
551
+
552
+ # If not found in metadata, try to infer from collection name
553
+ if not embedding_model_name:
554
+ # Collection name format: dataset_strategy_modelname
555
+ # Try common embedding models
556
+ known_models = [
557
+ "all-mpnet-base-v2",
558
+ "all-MiniLM-L6-v2",
559
+ "paraphrase-MiniLM-L6-v2",
560
+ "multi-qa-MiniLM-L6-cos-v1"
561
+ ]
562
+ for model in known_models:
563
+ if model.lower().replace("-", "") in collection_name.lower().replace("-", "").replace("_", ""):
564
+ embedding_model_name = f"sentence-transformers/{model}"
565
+ break
566
+ # Default fallback
567
+ if not embedding_model_name:
568
+ embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
569
+ print(f"[QDRANT] Warning: Could not determine embedding model, using default: {embedding_model_name}")
570
+
571
+ # Load the embedding model
572
+ if embedding_model_name:
573
+ self.embedding_model = EmbeddingFactory.create_embedding_model(embedding_model_name)
574
+ self.embedding_model.load_model()
575
+ print(f"[QDRANT] Loaded embedding model: {embedding_model_name}")
576
+
577
  print(f"[QDRANT] Loaded collection: {collection_name}")
578
  return self.current_collection
579
 
 
708
  "question": sample.get("question", ""),
709
  "answer": sample.get("answer", ""),
710
  "dataset": sample.get("dataset", ""),
711
+ "total_docs": len(documents),
712
+ "embedding_model": embedding_model_name,
713
+ "chunking_strategy": chunking_strategy
714
  })
715
 
716
  # Add all chunks to collection