File size: 4,973 Bytes
8cdd5f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79fc3f2
1323f77
8cdd5f1
 
 
 
79fc3f2
 
 
 
 
 
 
 
 
 
 
8cdd5f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import os
import torch
import faiss
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, DeclarativeBase, Mapped, mapped_column
from transformers import AutoTokenizer, AutoModel
from dotenv import load_dotenv
import boto3
import io

# Read environment variables (DATABASE_URL, AWS credentials) from a .env file.
load_dotenv()

# Run the embedding model on GPU when available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# HuggingFace model id used for both query embedding here and (presumably)
# the precomputed blog embeddings -- they must match for search to be valid.
MODEL_NAME = "nomic-ai/modernbert-embed-base"
# Location of the precomputed {blog_id: embedding} dict; "s3://" prefix
# switches loading to boto3, anything else is treated as a local path.
EMBEDDINGS_PATH = "s3://travel-recommender-s3/travel_blog_embeddings.pt"

# -----------------------------
# Load model + tokenizer
# -----------------------------
# NOTE: loaded eagerly at import time; importing this module downloads the
# model on first run and allocates GPU memory if available.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()

# -----------------------------
# Database model
# -----------------------------
class Base(DeclarativeBase):
    """Declarative base class shared by all ORM models in this module."""
    pass

class Whole_Blogs(Base):
    """ORM model for one scraped travel-blog page in the `travel_blogs` table.

    Each row pairs page metadata with a geocoded location; its `id` is the
    key into the precomputed embeddings dict loaded in _load_posts_and_index.
    """

    __tablename__ = "travel_blogs"

    id: Mapped[int] = mapped_column(primary_key=True)
    blog_url: Mapped[str]        # root URL of the blog site
    page_url: Mapped[str]        # URL of the specific post/page
    page_title: Mapped[str]
    page_description: Mapped[str]
    page_author: Mapped[str]
    # Free-text place name; search_modernbert assumes a trailing
    # ", Country" segment when extracting the country -- TODO confirm format.
    location_name: Mapped[str]
    latitude: Mapped[float]
    longitude: Mapped[float]
    content: Mapped[str]         # full page text; embeddings were built from this

# -----------------------------
# Cache
# -----------------------------
# Process-wide memoization populated once by _load_posts_and_index():
_cached_posts = None   # list[Whole_Blogs] in DB result order
_index = None          # faiss.IndexFlatL2 over _embeddings (row i <-> _cached_posts[i])
_embeddings = None     # torch.Tensor of shape (num_posts, dim)

# -----------------------------
# Embed helper for queries only
# -----------------------------
def embed_texts(texts_batch):
    """Embed a batch of strings with masked mean pooling + L2 normalization.

    Args:
        texts_batch: list of strings to embed (queries at search time).

    Returns:
        A CPU tensor of shape (len(texts_batch), hidden_dim) where every row
        has unit L2 norm.
    """
    tokens = tokenizer(
        texts_batch,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt"
    ).to(DEVICE)

    with torch.no_grad():
        hidden = model(**tokens).last_hidden_state

    # Mean-pool over real tokens only: zero out padding positions via the
    # attention mask, then divide by the per-sequence token count.
    mask = tokens["attention_mask"].unsqueeze(-1)
    token_counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against div-by-zero
    pooled = (hidden * mask).sum(dim=1) / token_counts

    # Unit-normalize so L2 distance in FAISS behaves like cosine distance.
    return torch.nn.functional.normalize(pooled, p=2, dim=1).cpu()

# -----------------------------
# Load posts and embeddings
# -----------------------------
def load_embeddings_from_s3(bucket: str, key: str):
    """Fetch a torch-serialized object from S3 and deserialize it.

    Args:
        bucket: S3 bucket name.
        key: object key within the bucket.

    Returns:
        Whatever torch.load yields -- here, a {blog_id: embedding} dict.
    """
    client = boto3.client("s3")
    response = client.get_object(Bucket=bucket, Key=key)
    payload = io.BytesIO(response["Body"].read())
    # weights_only=True refuses arbitrary pickled code, safe for tensor dicts.
    return torch.load(payload, weights_only=True)


def _load_posts_and_index():
    """Lazily load blog rows, their precomputed embeddings, and a FAISS index.

    Results are memoized in module-level globals, so the DB query and the
    S3/local embedding load happen at most once per process.

    Returns:
        Tuple (posts, index, embeddings):
        - posts: list of Whole_Blogs rows in DB result order,
        - index: faiss.IndexFlatL2 whose row i corresponds to posts[i],
        - embeddings: float32 tensor of shape (len(posts), dim).

    Raises:
        ValueError: if DATABASE_URL is unset, or any post lacks an embedding.
    """
    global _cached_posts, _index, _embeddings

    if _cached_posts is not None and _index is not None and _embeddings is not None:
        return _cached_posts, _index, _embeddings

    # Load metadata from DB
    database_url = os.getenv("DATABASE_URL")
    if not database_url:
        raise ValueError("DATABASE_URL not found in environment variables")
    engine = create_engine(database_url)
    with Session(engine) as session:
        _cached_posts = session.query(Whole_Blogs).all()

    # Load precomputed embeddings: dict of {blog_id: embedding_tensor}.
    if EMBEDDINGS_PATH.startswith("s3://"):
        path_without_s3 = EMBEDDINGS_PATH[5:]
        bucket_name, key = path_without_s3.split("/", 1)
        data = load_embeddings_from_s3(bucket_name, key)
    else:
        data = torch.load(EMBEDDINGS_PATH, weights_only=True)

    # Stack embeddings in the same order as _cached_posts so that FAISS row
    # ids map 1:1 onto list indices; fail loudly on any missing id.
    embedding_list = []
    for post in _cached_posts:
        if post.id not in data:
            raise ValueError(f"Missing embedding for blog post ID {post.id}")
        embedding_list.append(data[post.id])

    # FAISS only accepts float32 input; stored tensors may have been saved as
    # float64/float16, so cast explicitly instead of crashing inside faiss.
    _embeddings = torch.stack(embedding_list).to(torch.float32)

    # Build FAISS index (exact brute-force L2 search).
    _index = faiss.IndexFlatL2(_embeddings.shape[1])
    _index.add(_embeddings.numpy())

    return _cached_posts, _index, _embeddings

# -----------------------------
# Search function
# -----------------------------
def search_modernbert(query: str, top_k: int = 5):
    """Semantic search over travel-blog posts for a free-text query.

    Args:
        query: user query string; blank/whitespace-only returns [].
        top_k: maximum number of results to return.

    Returns:
        List of result dicts (destination, country, lat/lon, distance,
        page metadata, content preview, full content) ordered by ascending
        L2 distance. May contain fewer than top_k entries.
    """
    # Cheap guard first: don't trigger the DB/S3 load for an empty query.
    if not query.strip():
        return []

    posts, index, embeddings = _load_posts_and_index()

    q_emb = embed_texts([query]).numpy()
    distances, idxs = index.search(q_emb, top_k)

    results = []
    for i, idx in enumerate(idxs[0]):
        # FAISS pads with -1 when fewer than top_k vectors are indexed;
        # posts[-1] would silently return the *last* post, so skip instead.
        if idx < 0:
            continue
        post = posts[idx]
        content_preview = post.content[:300] + ("..." if len(post.content) > 300 else "")
        # Treat the segment after the last comma as the country, when present.
        location_parts = post.location_name.split(",")
        country = location_parts[-1].strip() if len(location_parts) > 1 else ""
        results.append({
            "destination": post.location_name,
            "country": country,
            "lat": post.latitude,
            "lon": post.longitude,
            "distance": float(distances[0][i]),
            "page_title": post.page_title,
            "page_url": post.page_url,
            "blog_url": post.blog_url,
            "author": post.page_author,
            "description": post.page_description,
            "content_preview": content_preview,
            "full_content": post.content
        })

    return results