VcRlAgent committed on
Commit
a123e22
·
1 Parent(s): b49feb6

Change encoder and retriever to add "query: " / "passage: " prefixes for the E5 embedding model

Browse files
.env.example CHANGED
@@ -12,7 +12,12 @@ HF_API_URL=https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Inst
12
  HF_TOKEN=your_huggingface_token_here
13
 
14
  # Embedding Model
15
- EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
 
 
 
 
 
16
 
17
  # Server Configuration
18
  HOST=0.0.0.0
 
12
  HF_TOKEN=your_huggingface_token_here
13
 
14
  # Embedding Model
15
+ #EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
16
+ #EMBEDDING_MODEL=sentence-transformers/multi-qa-MiniLM-L6-cos-v1
17
+ #EMBEDDING_MODEL=BAAI/bge-small-en-v1.5
18
+ EMBEDDING_MODEL=intfloat/e5-large-v2
19
+
20
+
21
 
22
  # Server Configuration
23
  HOST=0.0.0.0
app/config.py CHANGED
@@ -27,7 +27,8 @@ class Settings:
27
  # Embedding Model
28
  #EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
29
  #EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
30
- EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "BAAI/bge-small-en-v1.5")
 
31
 
32
  # Server Configuration
33
  HOST: str = os.getenv("HOST", "0.0.0.0")
@@ -40,5 +41,6 @@ class Settings:
40
  # Vector Search
41
  TOP_K: int = 5
42
  SCORE_THRESHOLD: float = 0.0
 
43
 
44
  settings = Settings()
 
27
  # Embedding Model
28
  #EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
29
  #EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
30
+ #EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "BAAI/bge-small-en-v1.5")
31
+ EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "intfloat/e5-large-v2")
32
 
33
  # Server Configuration
34
  HOST: str = os.getenv("HOST", "0.0.0.0")
 
41
  # Vector Search
42
  TOP_K: int = 5
43
  SCORE_THRESHOLD: float = 0.0
44
+ VECTOR_SIZE = 1024 # Must equal the embedding model's output dimension (1024 for intfloat/e5-large-v2); update when EMBEDDING_MODEL changes
45
 
46
  settings = Settings()
app/services/embeddings copy.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Embedding generation service using sentence-transformers"""
2
+ from sentence_transformers import SentenceTransformer
3
+ from typing import List
4
+ import numpy as np
5
+ from app.config import settings
6
+ from app.utils.logger import setup_logger
7
+
8
+ logger = setup_logger(__name__)
9
+
10
+ class EmbeddingService:
11
+ """Generate embeddings for text using sentence-transformers"""
12
+
13
+ def __init__(self):
14
+ """Initialize the embedding model"""
15
+ logger.info(f"Loading embedding model: {settings.EMBEDDING_MODEL}")
16
+ self.model = SentenceTransformer(settings.EMBEDDING_MODEL)
17
+ self.dimension = self.model.get_sentence_embedding_dimension()
18
+ logger.info(f"Embedding dimension: {self.dimension}")
19
+
20
+ def embed_text(self, text: str) -> List[float]:
21
+ """Generate embedding for a single text"""
22
+ embedding = self.model.encode(text, convert_to_numpy=True, normalize_embeddings=True)
23
+ #logger.debug(f"Generated embedding for text: {embedding}")
24
+ return embedding.tolist()
25
+
26
+
27
+ def embed_batch(self, texts: List[str], batch_size: int = 32) -> List[List[float]]:
28
+ """Generate embeddings for a batch of texts"""
29
+ logger.info(f"Embedding {len(texts)} texts...")
30
+ embeddings = self.model.encode(
31
+ texts,
32
+ batch_size=batch_size,
33
+ show_progress_bar=True,
34
+ convert_to_numpy=True,
35
+ normalize_embeddings=True
36
+ )
37
+ return embeddings.tolist()
38
+
39
+ def get_dimension(self) -> int:
40
+ """Return embedding dimension"""
41
+ logger.debug(f"Embedding dimension requested: {self.dimension}")
42
+ return self.dimension
43
+
44
+ # Global instance
45
+ embedding_service = EmbeddingService()
app/services/embeddings.py CHANGED
@@ -1,4 +1,4 @@
1
- """Embedding generation service using sentence-transformers"""
2
  from sentence_transformers import SentenceTransformer
3
  from typing import List
4
  import numpy as np
@@ -8,38 +8,64 @@ from app.utils.logger import setup_logger
8
  logger = setup_logger(__name__)
9
 
10
  class EmbeddingService:
11
- """Generate embeddings for text using sentence-transformers"""
12
-
 
 
 
 
13
  def __init__(self):
14
- """Initialize the embedding model"""
15
  logger.info(f"Loading embedding model: {settings.EMBEDDING_MODEL}")
16
  self.model = SentenceTransformer(settings.EMBEDDING_MODEL)
17
  self.dimension = self.model.get_sentence_embedding_dimension()
18
  logger.info(f"Embedding dimension: {self.dimension}")
19
-
20
- def embed_text(self, text: str) -> List[float]:
21
- """Generate embedding for a single text"""
22
- embedding = self.model.encode(text, convert_to_numpy=True, normalize_embeddings=True)
23
- #logger.debug(f"Generated embedding for text: {embedding}")
 
 
 
 
 
 
 
 
 
 
24
  return embedding.tolist()
25
-
26
-
27
- def embed_batch(self, texts: List[str], batch_size: int = 32) -> List[List[float]]:
28
- """Generate embeddings for a batch of texts"""
29
- logger.info(f"Embedding {len(texts)} texts...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  embeddings = self.model.encode(
31
- texts,
32
  batch_size=batch_size,
33
  show_progress_bar=True,
34
  convert_to_numpy=True,
35
- normalize_embeddings=True
36
  )
37
  return embeddings.tolist()
38
-
39
  def get_dimension(self) -> int:
40
- """Return embedding dimension"""
41
- logger.debug(f"Embedding dimension requested: {self.dimension}")
42
  return self.dimension
43
 
44
  # Global instance
45
- embedding_service = EmbeddingService()
 
1
+ """Embedding generation service using intfloat/e5-large-v2"""
2
  from sentence_transformers import SentenceTransformer
3
  from typing import List
4
  import numpy as np
 
8
  logger = setup_logger(__name__)
9
 
10
  class EmbeddingService:
11
+ """
12
+ Generate embeddings for text using intfloat/e5-large-v2.
13
+ Automatically prefixes 'query:' or 'passage:' as recommended
14
+ for retrieval tasks.
15
+ """
16
+
17
  def __init__(self):
 
18
  logger.info(f"Loading embedding model: {settings.EMBEDDING_MODEL}")
19
  self.model = SentenceTransformer(settings.EMBEDDING_MODEL)
20
  self.dimension = self.model.get_sentence_embedding_dimension()
21
  logger.info(f"Embedding dimension: {self.dimension}")
22
+
23
+ def embed_text(self, text: str, is_query: bool = False) -> List[float]:
24
+ """Generate embedding for a single text (query or passage)."""
25
+ if not text or not text.strip():
26
+ logger.warning("Empty text passed to embed_text()")
27
+ return []
28
+
29
+ prefix = "query: " if is_query else "passage: "
30
+ formatted_text = prefix + text.strip()
31
+
32
+ embedding = self.model.encode(
33
+ formatted_text,
34
+ convert_to_numpy=True,
35
+ normalize_embeddings=True,
36
+ )
37
  return embedding.tolist()
38
+
39
+ def embed_batch(
40
+ self,
41
+ texts: List[str],
42
+ batch_size: int = 32,
43
+ is_query: bool = False,
44
+ ) -> List[List[float]]:
45
+ """Generate embeddings for a batch of texts (queries or passages)."""
46
+ if not texts:
47
+ return []
48
+
49
+ prefix = "query: " if is_query else "passage: "
50
+ prefixed_texts = [prefix + t.strip() for t in texts]
51
+
52
+ logger.info(
53
+ f"Embedding {len(prefixed_texts)} texts using {settings.EMBEDDING_MODEL} "
54
+ f"(is_query={is_query})"
55
+ )
56
+
57
  embeddings = self.model.encode(
58
+ prefixed_texts,
59
  batch_size=batch_size,
60
  show_progress_bar=True,
61
  convert_to_numpy=True,
62
+ normalize_embeddings=True,
63
  )
64
  return embeddings.tolist()
65
+
66
  def get_dimension(self) -> int:
67
+ """Return embedding vector dimension."""
 
68
  return self.dimension
69
 
70
  # Global instance
71
+ embedding_service = EmbeddingService()
app/services/retriever.py CHANGED
@@ -25,8 +25,8 @@ class RetrieverService:
25
 
26
  # Generate query embedding
27
  logger.info(f"Retrieving documents for query: {query}")
28
- query_embedding = self.embedding_service.embed_text(query)
29
- logger.debug(f"Embedded query: {query_embedding}")
30
 
31
  #FAISS
32
  results = self.vector_store.search(
@@ -35,10 +35,6 @@ class RetrieverService:
35
  score_threshold=settings.SCORE_THRESHOLD
36
  )
37
 
38
- '''
39
- logger.debug(f"FAISS total vectors: {self.vector_store.index.ntotal}")
40
- D, I = self.vector_store.index.search(np.array([query_embedding]).astype("float32"), k=3)
41
- logger.debug(f"Distances: {D}, Indices: {I}")
42
  '''
43
  try:
44
  logger.warning(f"FAISS index object: {self.vector_store.index}")
@@ -47,12 +43,13 @@ class RetrieverService:
47
  else:
48
  logger.warning(f"FAISS total vectors: {self.vector_store.index.ntotal}")
49
  D, I = self.vector_store.index.search(
50
- np.array([query_embedding]).astype("float32"), k=3
51
  )
52
  logger.warning(f"Distances: {D}, Indices: {I}")
53
  except Exception as e:
54
  import traceback
55
  logger.error(f"FAISS search error: {e}\n{traceback.format_exc()}")
 
56
 
57
  #Qdrant
58
  # Search vector database
@@ -60,8 +57,7 @@ class RetrieverService:
60
  # query_vector=query_embedding,
61
  # limit=top_k,
62
  # score_threshold=settings.SCORE_THRESHOLD
63
- # )
64
-
65
 
66
  logger.info(f"Retrieved {len(results)} documents")
67
  return results
 
25
 
26
  # Generate query embedding
27
  logger.info(f"Retrieving documents for query: {query}")
28
+ query_embedding = self.embedding_service.embed_text(query,is_query=True)
29
+ #logger.debug(f"Embedded query: {query_embedding}")
30
 
31
  #FAISS
32
  results = self.vector_store.search(
 
35
  score_threshold=settings.SCORE_THRESHOLD
36
  )
37
 
 
 
 
 
38
  '''
39
  try:
40
  logger.warning(f"FAISS index object: {self.vector_store.index}")
 
43
  else:
44
  logger.warning(f"FAISS total vectors: {self.vector_store.index.ntotal}")
45
  D, I = self.vector_store.index.search(
46
+ np.array([query_embedding]).astype("float32"), k=top_k
47
  )
48
  logger.warning(f"Distances: {D}, Indices: {I}")
49
  except Exception as e:
50
  import traceback
51
  logger.error(f"FAISS search error: {e}\n{traceback.format_exc()}")
52
+ '''
53
 
54
  #Qdrant
55
  # Search vector database
 
57
  # query_vector=query_embedding,
58
  # limit=top_k,
59
  # score_threshold=settings.SCORE_THRESHOLD
60
+ # )
 
61
 
62
  logger.info(f"Retrieved {len(results)} documents")
63
  return results