college / embeddings /embedder.py
battulabhaskar543
fixed meta tensor error
81726ee
import numpy as np
import pickle
import os
import torch
from typing import List, Dict, Any
from sentence_transformers import SentenceTransformer
from config.config import Config
# Process-wide environment tweaks applied at import time to avoid device issues:
#   * TOKENIZERS_PARALLELISM="false" -- turn off tokenizer-level parallelism.
#   * CUDA_VISIBLE_DEVICES=""        -- expose no CUDA devices, i.e. force CPU.
# NOTE(review): these assignments execute after the imports above; confirm that
# is early enough for the libraries that read them.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "CUDA_VISIBLE_DEVICES": "",
    }
)
class Embedder:
    """Thin wrapper around a SentenceTransformer model that always runs on CPU.

    The model name is read from ``Config().EMBEDDING_MODEL``. The class exposes
    helpers to embed raw texts, chunk dictionaries, and single queries, plus a
    few no-op "legacy" methods kept so older call sites keep working.
    """

    def __init__(self):
        """Load the configured SentenceTransformer model on the CPU.

        Raises:
            RuntimeError: If the model cannot be loaded. The original error is
                attached as ``__cause__``.
        """
        self.config = Config()

        # Ensure the device is set to CPU even if CUDA is available.
        device = "cpu"

        # Only informational: the device stays "cpu" regardless of the result
        # (original note: Streamlit Cloud doesn't support GPU).
        if torch.cuda.is_available():
            print("CUDA is available, but we're forcing the use of CPU.")

        try:
            print(f"Loading model: {self.config.EMBEDDING_MODEL} on {device}")
            # Load the model with the specified device.
            self.model = SentenceTransformer(
                self.config.EMBEDDING_MODEL, device=device
            )
        except Exception as e:
            # Chain explicitly so the root cause is reported as the direct
            # cause rather than as an error raised while handling another.
            raise RuntimeError(
                f"Failed to load SentenceTransformer model: {str(e)}"
            ) from e

        # Sentence transformers don't need fitting, but we can save/load if
        # needed. No method of this class reads this path.
        self.model_path = "data/processed/sentence_transformer.pkl"

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for a list of texts using Sentence Transformers.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors (one per input text); an empty list when
            ``texts`` is empty.

        Raises:
            RuntimeError: If encoding fails. The original error is attached as
                ``__cause__``.
        """
        if not texts:
            return []

        try:
            # Ask for a numpy array so the result can be converted to plain
            # Python lists with ``tolist()``.
            embeddings = self.model.encode(texts, convert_to_numpy=True)
            return embeddings.tolist()
        except Exception as e:
            raise RuntimeError(f"Failed to generate embeddings: {str(e)}") from e

    def embed_chunks(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Generate embeddings for document chunks and add to chunk metadata.

        Each chunk must provide a ``"text"`` key. The dictionaries are updated
        in place with an ``"embedding"`` key, and the same list object is
        returned.

        Args:
            chunks: List of chunk dictionaries

        Returns:
            List of chunks with embeddings added; an empty list when ``chunks``
            is empty.
        """
        if not chunks:
            return []

        texts = [chunk["text"] for chunk in chunks]
        embeddings = self.embed_texts(texts)

        for chunk, embedding in zip(chunks, embeddings):
            chunk["embedding"] = embedding

        return chunks

    def embed_query(self, query: str) -> List[float]:
        """
        Generate embedding for a single query.

        Args:
            query: Query text

        Returns:
            Query embedding vector, or an empty list if no embedding was
            produced.
        """
        embeddings = self.embed_texts([query])
        return embeddings[0] if embeddings else []

    # Legacy methods for compatibility

    def fit_on_texts(self, texts: List[str]) -> None:
        """Do nothing; kept for compatibility (no fitting is needed)."""
        pass  # Not needed for sentence transformers

    def save_vectorizer(self) -> None:
        """Do nothing; kept for compatibility."""
        pass

    def load_vectorizer(self) -> bool:
        """Always report success; kept for compatibility."""
        return True