File size: 3,522 Bytes
ca7a2c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e98b5a
 
ca7a2c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51ba917
ca7a2c2
 
 
 
51ba917
ca7a2c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Text RAG Tool - Semantic search in text descriptions using pgvector.

Schema: place_text_embeddings (place_id, embedding, content_type, source_text, metadata)
        places_metadata (place_id, name, category, rating, raw_data)
"""

from dataclasses import dataclass, field
from collections import defaultdict
from typing import Optional

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from app.shared.integrations.embedding_client import embedding_client


@dataclass
class TextSearchResult:
    """One place returned by a text-context (semantic) search.

    The first five fields are required; the remaining text fields are
    optional and default to the empty string.

    Attributes:
        place_id: Identifier of the matched place.
        name: Name of the place.
        category: Category of the place.
        rating: Rating of the place.
        similarity: Match score; results are ranked with higher scores first.
        description: Description text for the place, if any.
        source_text: Text of the embedded snippet that matched, if any.
        content_type: Content type of the matched snippet, if any.
    """

    # Required fields.
    place_id: str
    name: str
    category: str
    rating: float
    similarity: float

    # Optional text fields.
    description: str = ""
    source_text: str = ""
    content_type: str = ""


# Tool definition for agent - imported from centralized prompts
from app.shared.prompts import RETRIEVE_CONTEXT_TEXT_TOOL as TOOL_DEFINITION


async def retrieve_context_text(
    db: AsyncSession,
    query: str,
    limit: int = 10,
    threshold: float = 0.3,
) -> list[TextSearchResult]:
    """
    Semantic search in text descriptions using pgvector.

    Embeds ``query``, compares it against ``place_text_embeddings`` (joined
    to ``places_metadata``) and keeps, for every place, only its closest
    embedding row. Each hit is then re-scored with a small rating boost and
    the best ``limit`` hits are returned, highest score first.

    Args:
        db: Database session
        query: Natural language query
        limit: Maximum results; a value <= 0 yields an empty list
        threshold: Minimum similarity (before the rating boost) a row must
            exceed to be considered

    Returns:
        List of places ordered by descending score. Note that
        ``TextSearchResult.similarity`` holds the *boosted* score rounded to
        4 decimals, so it can be slightly above the raw similarity.
    """
    # Nothing can be returned for a non-positive limit, so skip the embedding
    # call and the query entirely (a negative limit would otherwise slice off
    # the tail of the result list instead of limiting it).
    if limit <= 0:
        return []

    max_text_len = 300  # Truncation length for description / source_text

    # Generate embedding for query
    query_embedding = await embedding_client.embed_text(query)

    # Convert to PostgreSQL vector format. Every component is forced through
    # float() so that only numeric literals can ever reach the SQL string
    # below, whatever the embedding client returns.
    embedding_str = "[" + ",".join(str(float(x)) for x in query_embedding) + "]"

    # Search with JOIN to places_metadata
    # Note: Use format string for embedding since SQLAlchemy param binding
    # doesn't work correctly with ::vector type casting
    # DISTINCT ON + ORDER BY (place_id, distance) keeps the single closest
    # embedding row per place.
    sql = text(f"""
        SELECT DISTINCT ON (e.place_id)
            e.place_id,
            e.content_type,
            e.source_text,
            1 - (e.embedding <=> '{embedding_str}'::vector) as similarity,
            m.name,
            m.category,
            m.rating,
            m.raw_data
        FROM place_text_embeddings e
        JOIN places_metadata m ON e.place_id = m.place_id
        WHERE 1 - (e.embedding <=> '{embedding_str}'::vector) > :threshold
          AND m.name IS NOT NULL
          AND m.name != ''
        ORDER BY e.place_id, e.embedding <=> '{embedding_str}'::vector
    """)

    results = await db.execute(sql, {
        "threshold": threshold,
    })

    rows = results.fetchall()

    # Process and score results with rating boost
    scored_results = []
    for r in rows:
        score = float(r.similarity)

        # Rating boost (5% for >= 4.5, 2% for >= 4.0)
        if r.rating and r.rating >= 4.5:
            score += 0.05
        elif r.rating and r.rating >= 4.0:
            score += 0.02

        # raw_data may be NULL or a non-dict value, and the stored
        # 'description' may be missing, null or not a string; all of these
        # degrade to an empty description instead of raising.
        raw_data = r.raw_data if isinstance(r.raw_data, dict) else {}
        description = raw_data.get('description')
        if not isinstance(description, str):
            description = ''

        scored_results.append((score, TextSearchResult(
            place_id=r.place_id,
            name=r.name or '',
            category=r.category or '',
            rating=float(r.rating) if r.rating else 0.0,
            similarity=round(score, 4),
            description=description[:max_text_len],
            source_text=r.source_text[:max_text_len] if r.source_text else '',
            content_type=r.content_type or '',
        )))

    # Sort by the unrounded boosted score and limit
    scored_results.sort(key=lambda x: x[0], reverse=True)

    return [r for _, r in scored_results[:limit]]