"""Semantic search over travel-blog posts using ModernBERT embeddings + FAISS.

Post embeddings are precomputed offline and stored as a ``{blog_id: tensor}``
dict in S3 (or on local disk); only the user's query is embedded at request
time.  Post metadata lives in a SQL database (``travel_blogs`` table).
"""

import io
import os

import boto3
import faiss
import torch
from dotenv import load_dotenv
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
from transformers import AutoModel, AutoTokenizer

load_dotenv()

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "nomic-ai/modernbert-embed-base"
EMBEDDINGS_PATH = "s3://travel-recommender-s3/travel_blog_embeddings.pt"

# -----------------------------
# Load model + tokenizer
# -----------------------------
# Module-level side effect: downloads/loads model weights on first import.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()


# -----------------------------
# Database model
# -----------------------------
class Base(DeclarativeBase):
    pass


class Whole_Blogs(Base):
    """ORM row for one scraped travel-blog page (table ``travel_blogs``)."""

    __tablename__ = "travel_blogs"

    id: Mapped[int] = mapped_column(primary_key=True)
    blog_url: Mapped[str]
    page_url: Mapped[str]
    page_title: Mapped[str]
    page_description: Mapped[str]
    page_author: Mapped[str]
    location_name: Mapped[str]
    latitude: Mapped[float]
    longitude: Mapped[float]
    content: Mapped[str]


# -----------------------------
# Cache (populated once per process by _load_posts_and_index)
# -----------------------------
_cached_posts = None  # list[Whole_Blogs], aligned with _embeddings rows
_index = None         # faiss.IndexFlatL2 over _embeddings
_embeddings = None    # torch.Tensor, shape (num_posts, dim)


# -----------------------------
# Embed helper for queries only
# -----------------------------
def embed_texts(texts_batch):
    """Embed a batch of texts with mean pooling + L2 normalization.

    Args:
        texts_batch: list of strings to embed.

    Returns:
        torch.Tensor of shape (len(texts_batch), hidden_dim) on CPU,
        L2-normalized row-wise.

    NOTE(review): nomic-embed models conventionally expect task prefixes
    ("search_query: " / "search_document: ") on the input text — confirm the
    offline document embeddings were built with the same convention as the
    queries embedded here.
    """
    encoded = tokenizer(
        texts_batch,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    ).to(DEVICE)

    with torch.no_grad():
        outputs = model(**encoded)

    # Mean-pool token embeddings, ignoring padding positions.
    last_hidden = outputs.last_hidden_state
    attention_mask = encoded["attention_mask"].unsqueeze(-1)
    sum_embeddings = torch.sum(last_hidden * attention_mask, dim=1)
    sum_mask = torch.sum(attention_mask, dim=1)
    sum_mask = torch.clamp(sum_mask, min=1e-9)  # avoid divide-by-zero
    embedding = sum_embeddings / sum_mask
    embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
    return embedding.cpu()


# -----------------------------
# Load posts and embeddings
# -----------------------------
def load_embeddings_from_s3(bucket: str, key: str):
    """Download a torch-serialized object from S3 and deserialize it.

    Args:
        bucket: S3 bucket name.
        key: object key within the bucket.

    Returns:
        The deserialized object (here: a ``{blog_id: embedding_tensor}`` dict).
    """
    s3 = boto3.client("s3")
    obj = s3.get_object(Bucket=bucket, Key=key)
    buffer = io.BytesIO(obj["Body"].read())
    # weights_only=True restricts unpickling to tensors/primitives (safer).
    return torch.load(buffer, weights_only=True)


def _load_posts_and_index():
    """Load post metadata, embeddings, and FAISS index; cache at module level.

    Returns:
        (posts, index, embeddings) where ``posts`` is a list of Whole_Blogs
        rows, ``index`` a faiss.IndexFlatL2, and ``embeddings`` the stacked
        tensor whose row order matches ``posts``.

    Raises:
        ValueError: if DATABASE_URL is unset, the table is empty, or an
            embedding is missing for a post.
    """
    global _cached_posts, _index, _embeddings

    if _cached_posts is not None and _index is not None and _embeddings is not None:
        return _cached_posts, _index, _embeddings

    # Load metadata from DB
    database_url = os.getenv("DATABASE_URL")
    if not database_url:
        raise ValueError("DATABASE_URL not found in environment variables")

    engine = create_engine(database_url)
    with Session(engine) as session:
        _cached_posts = session.query(Whole_Blogs).all()

    # torch.stack([]) would raise an opaque RuntimeError below; fail clearly.
    if not _cached_posts:
        raise ValueError("No rows found in travel_blogs table")

    # Load precomputed embeddings
    if EMBEDDINGS_PATH.startswith("s3://"):
        path_without_s3 = EMBEDDINGS_PATH[5:]
        bucket_name, key = path_without_s3.split("/", 1)
        data = load_embeddings_from_s3(bucket_name, key)
    else:
        data = torch.load(EMBEDDINGS_PATH, weights_only=True)

    # Convert embeddings dict to tensor array.
    # Data is {blog_id: embedding_tensor, ...}; stack into a tensor whose
    # row order is aligned with _cached_posts.
    embedding_list = []
    for post in _cached_posts:
        if post.id in data:
            embedding_list.append(data[post.id])
        else:
            raise ValueError(f"Missing embedding for blog post ID {post.id}")

    _embeddings = torch.stack(embedding_list)

    # Build FAISS index (exact L2 search; embeddings are L2-normalized, so
    # L2 ranking is equivalent to cosine-similarity ranking).
    _index = faiss.IndexFlatL2(_embeddings.shape[1])
    _index.add(_embeddings.numpy())

    return _cached_posts, _index, _embeddings


# -----------------------------
# Search function
# -----------------------------
def search_modernbert(query: str, top_k: int = 5):
    """Return the ``top_k`` blog posts most similar to ``query``.

    Args:
        query: free-text search query. Blank/whitespace queries return [].
        top_k: maximum number of results to return.

    Returns:
        List of result dicts (destination, country, lat/lon, distance,
        page/blog metadata, content preview and full content), ordered by
        ascending L2 distance.
    """
    # Cheap guards before the expensive DB/S3/index load.
    if not query.strip() or top_k <= 0:
        return []

    posts, index, embeddings = _load_posts_and_index()

    q_emb = embed_texts([query]).numpy()
    # Clamp k: faiss pads missing neighbors with index -1, and posts[-1]
    # would silently return the *last* post instead of failing.
    k = min(top_k, index.ntotal)
    distances, idxs = index.search(q_emb, k)

    results = []
    for i, idx in enumerate(idxs[0]):
        if idx < 0:  # defensive: faiss sentinel for "no neighbor"
            continue
        post = posts[idx]
        content_preview = post.content[:300] + ("..." if len(post.content) > 300 else "")
        # Heuristic: treat the text after the last comma as the country.
        location_parts = post.location_name.split(",")
        country = location_parts[-1].strip() if len(location_parts) > 1 else ""
        results.append({
            "destination": post.location_name,
            "country": country,
            "lat": post.latitude,
            "lon": post.longitude,
            "distance": float(distances[0][i]),
            "page_title": post.page_title,
            "page_url": post.page_url,
            "blog_url": post.blog_url,
            "author": post.page_author,
            "description": post.page_description,
            "content_preview": content_preview,
            "full_content": post.content,
        })

    return results