Spaces:
Sleeping
Sleeping
File size: 5,262 Bytes
8cdd5f1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 | # backend/src/api/bm25_utils.py
import re
import os
from typing import List, Dict
from rank_bm25 import BM25Okapi
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, DeclarativeBase, Mapped, mapped_column
from dotenv import load_dotenv
# Import Logger
from .logging_utils import get_logger
# Load environment variables
load_dotenv()
logger = get_logger("bm25_utils")
# Tokenizer
TOKEN_RE = re.compile(r"\b\w+\b")
def tokenize(text: str) -> List[str]:
"""Tokenize text into lowercase words."""
if not text:
return []
return TOKEN_RE.findall(text.lower())
# Database Model
class Base(DeclarativeBase):
pass
class Whole_Blogs(Base):
__tablename__ = "travel_blogs"
id: Mapped[int] = mapped_column(primary_key=True)
blog_url: Mapped[str]
page_url: Mapped[str] = mapped_column(unique=True, nullable=False)
page_title: Mapped[str]
page_description: Mapped[str]
page_author: Mapped[str]
location_name: Mapped[str]
latitude: Mapped[float]
longitude: Mapped[float]
content: Mapped[str]
#embedding: Mapped[list[float]] = mapped_column(ARRAY(Float)) # new column
def __repr__(self) -> str:
return f"Whole_Blogs(id={self.id!r}, location_name={self.location_name!r}, page_title={self.page_title!r})"
# Cache for loaded data (so we don't reload from DB on every search)
_cached_posts = None
_cached_bm25 = None
def _load_blogs_from_db():
"""Load blog posts from database and build BM25 index."""
global _cached_posts, _cached_bm25
database_url = os.getenv("DATABASE_URL")
if not database_url:
logger.error("DATABASE_URL not found in environment variables")
raise ValueError("DATABASE_URL not found in environment variables")
logger.info("Loading blog posts from database...")
engine = create_engine(database_url)
posts = []
corpus = []
with Session(engine) as session:
blog_posts = session.query(Whole_Blogs).all()
logger.info(f"Loaded {len(blog_posts)} blog posts from database")
for post in blog_posts:
posts.append({
"id": post.id,
"location_name": post.location_name,
"page_title": post.page_title,
"page_description": post.page_description,
"page_author": post.page_author,
"page_url": post.page_url,
"blog_url": post.blog_url,
"latitude": post.latitude,
"longitude": post.longitude,
"content": post.content,
})
# Combine title, description, and content for searching
search_text = f"{post.page_title} {post.page_description} {post.content}"
corpus.append(search_text)
# Build BM25 index
logger.info("Building BM25 index...")
tokenized_corpus = [tokenize(doc) for doc in corpus]
bm25 = BM25Okapi(tokenized_corpus)
logger.info(f"BM25 index built with {len(posts)} documents")
# Cache the results
_cached_posts = posts
_cached_bm25 = bm25
return posts, bm25
def search_bm25(query: str, top_n: int = 12) -> List[Dict]:
"""
Search blog posts using BM25.
Args:
query: Search query string
top_n: Number of top results to return
Returns:
List of dicts with search results
"""
global _cached_posts, _cached_bm25
# Load data if not already cached
if _cached_posts is None or _cached_bm25 is None:
_load_blogs_from_db()
if not query.strip():
return []
# Tokenize query
tokenized_query = tokenize(query)
if not tokenized_query:
return []
# Get BM25 scores
scores = _cached_bm25.get_scores(tokenized_query)
# Get top N document indices
top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_n]
# Build results
results = []
for idx in top_indices:
post = _cached_posts[idx]
# Create content preview (first 300 chars)
content_preview = post.get("content", "")[:300]
if len(post.get("content", "")) > 300:
content_preview += "..."
# Try to extract country from location_name (if formatted like "City, Country")
location_parts = post.get("location_name", "").split(",")
country = location_parts[-1].strip() if len(location_parts) > 1 else ""
results.append({
"destination": post.get("location_name", "Unknown"),
"country": country,
"lat": float(post.get("latitude", 0)) if post.get("latitude") else None,
"lon": float(post.get("longitude", 0)) if post.get("longitude") else None,
"score": float(scores[idx]),
"page_title": post.get("page_title", ""),
"page_url": post.get("page_url", ""),
"blog_url": post.get("blog_url", ""),
"author": post.get("page_author", ""),
"description": post.get("page_description", ""),
"content_preview": content_preview,
"full_content": post.get('content', ""),
})
return results |