0504ankitsharma commited on
Commit
9a32e55
·
verified ·
1 Parent(s): c90b18d

Update app/services/embedding_service.py

Browse files
Files changed (1) hide show
  1. app/services/embedding_service.py +26 -38
app/services/embedding_service.py CHANGED
@@ -1,6 +1,6 @@
1
- from typing import List
2
  import logging
3
- from openai import OpenAI
 
4
  from app.config import settings
5
 
6
  logger = logging.getLogger(__name__)
@@ -8,61 +8,49 @@ logger = logging.getLogger(__name__)
8
  class EmbeddingService:
9
  def __init__(self):
10
  try:
11
- self.client = OpenAI(api_key=settings.OPENAI_API_KEY)
12
- self.model_name = settings.EMBEDDING_MODEL or "llama-text-embed-v2"
13
- self.dimension = int(settings.PINECONE_DIMENSION)
14
- logger.info(f"🔹 Using embedding model: {self.model_name}")
15
  except Exception as e:
16
- logger.error(f"Error initializing embedding service: {e}")
17
  raise
18
 
19
  async def embed_text(self, text: str) -> List[float]:
20
- """Generate embeddings for a single text using the Llama model."""
21
  try:
22
- response = self.client.embeddings.create(
23
  model=self.model_name,
24
- input=text
25
  )
26
- embedding = response.data[0].embedding
27
 
28
- # Ensure correct dimensionality
29
- if len(embedding) < self.dimension:
30
- embedding += [0.0] * (self.dimension - len(embedding))
31
- elif len(embedding) > self.dimension:
32
- embedding = embedding[:self.dimension]
 
33
 
34
  return embedding
35
  except Exception as e:
36
- logger.error(f"Error generating embedding: {e}")
37
  raise
38
 
39
  async def embed_batch(self, texts: List[str]) -> List[List[float]]:
40
- """Generate embeddings for a batch of texts."""
41
  try:
42
- response = self.client.embeddings.create(
43
- model=self.model_name,
44
- input=texts
45
- )
46
- embeddings = [d.embedding for d in response.data]
47
-
48
- # Pad/truncate each embedding
49
- fixed_embeddings = []
50
- for emb in embeddings:
51
- if len(emb) < self.dimension:
52
- emb += [0.0] * (self.dimension - len(emb))
53
- elif len(emb) > self.dimension:
54
- emb = emb[:self.dimension]
55
- fixed_embeddings.append(emb)
56
-
57
- return fixed_embeddings
58
  except Exception as e:
59
  logger.error(f"Error generating batch embeddings: {e}")
60
  raise
61
 
62
- async def encode_product(self, product) -> List[float]:
63
- """Combine product info for embedding."""
64
- text = f"{product.title or ''} {product.brand or ''} {product.material or ''} {product.color or ''} {' '.join(product.categories) if product.categories else ''}"
65
- return await self.embed_text(text)
66
 
67
  # Global instance
68
  embedding_service = EmbeddingService()
 
 
1
  import logging
2
+ from typing import List
3
+ import google.generativeai as genai
4
  from app.config import settings
5
 
6
  logger = logging.getLogger(__name__)
 
8
  class EmbeddingService:
9
  def __init__(self):
10
  try:
11
+ genai.configure(api_key=settings.GEMINI_API_KEY)
12
+ self.model_name = "models/embedding-001" # Gemini text embedding model
13
+ logger.info(f"🔹 Using Gemini embedding model: {self.model_name}")
 
14
  except Exception as e:
15
+ logger.error(f"Error initializing Gemini embedding service: {e}")
16
  raise
17
 
18
  async def embed_text(self, text: str) -> List[float]:
19
+ """Generate text embeddings using Gemini API"""
20
  try:
21
+ response = genai.embed_content(
22
  model=self.model_name,
23
+ content=text
24
  )
25
+ embedding = response["embedding"]
26
 
27
+ # Ensure vector dimension matches Pinecone index (1024)
28
+ if len(embedding) < settings.PINECONE_DIMENSION:
29
+ padding = [0.0] * (settings.PINECONE_DIMENSION - len(embedding))
30
+ embedding.extend(padding)
31
+ elif len(embedding) > settings.PINECONE_DIMENSION:
32
+ embedding = embedding[:settings.PINECONE_DIMENSION]
33
 
34
  return embedding
35
  except Exception as e:
36
+ logger.error(f"Error generating Gemini embedding: {e}")
37
  raise
38
 
39
  async def embed_batch(self, texts: List[str]) -> List[List[float]]:
40
+ """Generate batch embeddings"""
41
  try:
42
+ embeddings = []
43
+ for text in texts:
44
+ response = genai.embed_content(
45
+ model=self.model_name,
46
+ content=text
47
+ )
48
+ embeddings.append(response["embedding"])
49
+ return embeddings
 
 
 
 
 
 
 
 
50
  except Exception as e:
51
  logger.error(f"Error generating batch embeddings: {e}")
52
  raise
53
 
 
 
 
 
54
 
55
  # Global instance
56
  embedding_service = EmbeddingService()