import os
import torch
import faiss
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, DeclarativeBase, Mapped, mapped_column
from transformers import AutoTokenizer, AutoModel
from dotenv import load_dotenv
import boto3
import io
load_dotenv()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "nomic-ai/modernbert-embed-base"
EMBEDDINGS_PATH = "s3://travel-recommender-s3/travel_blog_embeddings.pt"
# -----------------------------
# Load model + tokenizer
# -----------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()
# -----------------------------
# Database model
# -----------------------------
class Base(DeclarativeBase):
pass
class Whole_Blogs(Base):
__tablename__ = "travel_blogs"
id: Mapped[int] = mapped_column(primary_key=True)
blog_url: Mapped[str]
page_url: Mapped[str]
page_title: Mapped[str]
page_description: Mapped[str]
page_author: Mapped[str]
location_name: Mapped[str]
latitude: Mapped[float]
longitude: Mapped[float]
content: Mapped[str]
# -----------------------------
# Cache
# -----------------------------
_cached_posts = None
_index = None
_embeddings = None
# -----------------------------
# Embed helper for queries only
# -----------------------------
def embed_texts(texts_batch):
encoded = tokenizer(
texts_batch,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
).to(DEVICE)
with torch.no_grad():
outputs = model(**encoded)
last_hidden = outputs.last_hidden_state
attention_mask = encoded["attention_mask"].unsqueeze(-1)
sum_embeddings = torch.sum(last_hidden * attention_mask, dim=1)
sum_mask = torch.sum(attention_mask, dim=1)
sum_mask = torch.clamp(sum_mask, min=1e-9)
embedding = sum_embeddings / sum_mask
embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
return embedding.cpu()
# -----------------------------
# Load posts and embeddings
# -----------------------------
def load_embeddings_from_s3(bucket: str, key: str):
s3 = boto3.client("s3")
obj = s3.get_object(Bucket=bucket, Key=key)
buffer = io.BytesIO(obj["Body"].read())
return torch.load(buffer, weights_only=True)
def _load_posts_and_index():
global _cached_posts, _index, _embeddings
if _cached_posts is not None and _index is not None and _embeddings is not None:
return _cached_posts, _index, _embeddings
# Load metadata from DB
database_url = os.getenv("DATABASE_URL")
if not database_url:
raise ValueError("DATABASE_URL not found in environment variables")
engine = create_engine(database_url)
with Session(engine) as session:
_cached_posts = session.query(Whole_Blogs).all()
# Load precomputed embeddings
if EMBEDDINGS_PATH.startswith("s3://"):
path_without_s3 = EMBEDDINGS_PATH[5:]
bucket_name, key = path_without_s3.split("/", 1)
data = load_embeddings_from_s3(bucket_name, key)
else:
data = torch.load(EMBEDDINGS_PATH, weights_only=True)
# Convert embeddings dict to tensor array
# Data is {blog_id: embedding_tensor, ...}
# We need to convert to a stacked tensor aligned with _cached_posts order
embedding_list = []
for post in _cached_posts:
if post.id in data:
embedding_list.append(data[post.id])
else:
raise ValueError(f"Missing embedding for blog post ID {post.id}")
_embeddings = torch.stack(embedding_list)
# Build FAISS index
_index = faiss.IndexFlatL2(_embeddings.shape[1])
_index.add(_embeddings.numpy())
return _cached_posts, _index, _embeddings
# -----------------------------
# Search function
# -----------------------------
def search_modernbert(query: str, top_k: int = 5):
posts, index, embeddings = _load_posts_and_index()
if not query.strip():
return []
q_emb = embed_texts([query]).numpy()
distances, idxs = index.search(q_emb, top_k)
results = []
for i, idx in enumerate(idxs[0]):
post = posts[idx]
content_preview = post.content[:300] + ("..." if len(post.content) > 300 else "")
location_parts = post.location_name.split(",")
country = location_parts[-1].strip() if len(location_parts) > 1 else ""
results.append({
"destination": post.location_name,
"country": country,
"lat": post.latitude,
"lon": post.longitude,
"distance": float(distances[0][i]),
"page_title": post.page_title,
"page_url": post.page_url,
"blog_url": post.blog_url,
"author": post.page_author,
"description": post.page_description,
"content_preview": content_preview,
"full_content": post.content
})
return results