# travel-recommender-api / src/api/modern_bert_utils.py
# (Hugging Face page header residue: dcorcoran, "Changed embedding pulls", commit 79fc3f2)
import os
import torch
import faiss
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, DeclarativeBase, Mapped, mapped_column
from transformers import AutoTokenizer, AutoModel
from dotenv import load_dotenv
import boto3
import io
load_dotenv()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "nomic-ai/modernbert-embed-base"
EMBEDDINGS_PATH = "s3://travel-recommender-s3/travel_blog_embeddings.pt"
# -----------------------------
# Load model + tokenizer
# -----------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()
# -----------------------------
# Database model
# -----------------------------
class Base(DeclarativeBase):
pass
class Whole_Blogs(Base):
__tablename__ = "travel_blogs"
id: Mapped[int] = mapped_column(primary_key=True)
blog_url: Mapped[str]
page_url: Mapped[str]
page_title: Mapped[str]
page_description: Mapped[str]
page_author: Mapped[str]
location_name: Mapped[str]
latitude: Mapped[float]
longitude: Mapped[float]
content: Mapped[str]
# -----------------------------
# Cache
# -----------------------------
_cached_posts = None
_index = None
_embeddings = None
# -----------------------------
# Embed helper for queries only
# -----------------------------
def embed_texts(texts_batch):
encoded = tokenizer(
texts_batch,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
).to(DEVICE)
with torch.no_grad():
outputs = model(**encoded)
last_hidden = outputs.last_hidden_state
attention_mask = encoded["attention_mask"].unsqueeze(-1)
sum_embeddings = torch.sum(last_hidden * attention_mask, dim=1)
sum_mask = torch.sum(attention_mask, dim=1)
sum_mask = torch.clamp(sum_mask, min=1e-9)
embedding = sum_embeddings / sum_mask
embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
return embedding.cpu()
# -----------------------------
# Load posts and embeddings
# -----------------------------
def load_embeddings_from_s3(bucket: str, key: str):
s3 = boto3.client("s3")
obj = s3.get_object(Bucket=bucket, Key=key)
buffer = io.BytesIO(obj["Body"].read())
return torch.load(buffer, weights_only=True)
def _load_posts_and_index():
global _cached_posts, _index, _embeddings
if _cached_posts is not None and _index is not None and _embeddings is not None:
return _cached_posts, _index, _embeddings
# Load metadata from DB
database_url = os.getenv("DATABASE_URL")
if not database_url:
raise ValueError("DATABASE_URL not found in environment variables")
engine = create_engine(database_url)
with Session(engine) as session:
_cached_posts = session.query(Whole_Blogs).all()
# Load precomputed embeddings
if EMBEDDINGS_PATH.startswith("s3://"):
path_without_s3 = EMBEDDINGS_PATH[5:]
bucket_name, key = path_without_s3.split("/", 1)
data = load_embeddings_from_s3(bucket_name, key)
else:
data = torch.load(EMBEDDINGS_PATH, weights_only=True)
# Convert embeddings dict to tensor array
# Data is {blog_id: embedding_tensor, ...}
# We need to convert to a stacked tensor aligned with _cached_posts order
embedding_list = []
for post in _cached_posts:
if post.id in data:
embedding_list.append(data[post.id])
else:
raise ValueError(f"Missing embedding for blog post ID {post.id}")
_embeddings = torch.stack(embedding_list)
# Build FAISS index
_index = faiss.IndexFlatL2(_embeddings.shape[1])
_index.add(_embeddings.numpy())
return _cached_posts, _index, _embeddings
# -----------------------------
# Search function
# -----------------------------
def search_modernbert(query: str, top_k: int = 5):
posts, index, embeddings = _load_posts_and_index()
if not query.strip():
return []
q_emb = embed_texts([query]).numpy()
distances, idxs = index.search(q_emb, top_k)
results = []
for i, idx in enumerate(idxs[0]):
post = posts[idx]
content_preview = post.content[:300] + ("..." if len(post.content) > 300 else "")
location_parts = post.location_name.split(",")
country = location_parts[-1].strip() if len(location_parts) > 1 else ""
results.append({
"destination": post.location_name,
"country": country,
"lat": post.latitude,
"lon": post.longitude,
"distance": float(distances[0][i]),
"page_title": post.page_title,
"page_url": post.page_url,
"blog_url": post.blog_url,
"author": post.page_author,
"description": post.page_description,
"content_preview": content_preview,
"full_content": post.content
})
return results