| import numpy as np |
| import pickle |
| import os |
| import torch |
| from typing import List, Dict, Any |
| from sentence_transformers import SentenceTransformer |
| from config.config import Config |
|
|
|
|
| |
# Process-wide settings applied at import time:
#   * TOKENIZERS_PARALLELISM="false" turns off the Hugging Face tokenizers'
#     internal parallelism.
#   * CUDA_VISIBLE_DEVICES="" exposes no CUDA devices to this process, which
#     matches the CPU-only device used by Embedder below.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "CUDA_VISIBLE_DEVICES": "",
    }
)
|
|
|
|
class Embedder:
    """Generate text embeddings with a Sentence Transformers model on the CPU.

    The model named by ``Config().EMBEDDING_MODEL`` is loaded once in
    ``__init__`` and reused by every ``embed_*`` method.
    """

    def __init__(self):
        """Load the configured SentenceTransformer model on the CPU.

        Raises:
            RuntimeError: If the model cannot be loaded. The underlying
                exception is attached as ``__cause__``.
        """
        self.config = Config()

        # The device is fixed to CPU regardless of what hardware is present.
        device = "cpu"

        if torch.cuda.is_available():
            print("CUDA is available, but we're forcing the use of CPU.")

        try:
            print(f"Loading model: {self.config.EMBEDDING_MODEL} on {device}")
            self.model = SentenceTransformer(self.config.EMBEDDING_MODEL, device=device)
        except Exception as e:
            # Chain the original error so the real cause stays in the traceback.
            raise RuntimeError(f"Failed to load SentenceTransformer model: {e}") from e

        # Not read by any method of this class; kept as a public attribute.
        self.model_path = "data/processed/sentence_transformer.pkl"

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for a list of texts using Sentence Transformers.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors, one per input text (empty list when
            ``texts`` is empty; the model is not called in that case)

        Raises:
            RuntimeError: If encoding fails. The underlying exception is
                attached as ``__cause__``.
        """
        if not texts:
            return []

        try:
            embeddings = self.model.encode(texts, convert_to_numpy=True)
            return embeddings.tolist()
        except Exception as e:
            raise RuntimeError(f"Failed to generate embeddings: {e}") from e

    def embed_chunks(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Generate embeddings for document chunks and add to chunk metadata.

        Each chunk dictionary is updated in place with an ``"embedding"`` key;
        the same list object that was passed in is returned.

        Args:
            chunks: List of chunk dictionaries, each with a ``"text"`` key

        Returns:
            List of chunks with embeddings added (empty list when ``chunks``
            is empty)

        Raises:
            KeyError: If a chunk has no ``"text"`` key.
            RuntimeError: If encoding fails (propagated from ``embed_texts``).
        """
        if not chunks:
            return []

        texts = [chunk["text"] for chunk in chunks]
        embeddings = self.embed_texts(texts)

        for chunk, embedding in zip(chunks, embeddings):
            chunk["embedding"] = embedding

        return chunks

    def embed_query(self, query: str) -> List[float]:
        """
        Generate embedding for a single query.

        Args:
            query: Query text

        Returns:
            Query embedding vector (empty list if no embedding was produced)

        Raises:
            RuntimeError: If encoding fails (propagated from ``embed_texts``).
        """
        embeddings = self.embed_texts([query])
        return embeddings[0] if embeddings else []

    # The three methods below do no work.
    # NOTE(review): presumably retained so callers written against a
    # fit/save/load style embedder keep working - confirm before removing.

    def fit_on_texts(self, texts: List[str]) -> None:
        """Do nothing; ``texts`` is ignored."""
        pass

    def save_vectorizer(self) -> None:
        """Do nothing; nothing is written to disk."""
        pass

    def load_vectorizer(self) -> bool:
        """Do nothing and report success by always returning ``True``."""
        return True
|
|