File size: 12,923 Bytes
1d32142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
"""Vector storage and retrieval for donor/volunteer embeddings.

Uses the existing my_embeddings table in Supabase with pgvector extension.
"""

import json
from typing import List, Optional, Dict, Any, Union
from dataclasses import dataclass
import numpy as np


def _parse_json_field(value: Union[str, dict, None]) -> dict:
    """Safely parse a JSON field that might already be a dict (psycopg3 auto-parses)."""
    if value is None:
        return {}
    if isinstance(value, dict):
        return value
    if isinstance(value, str):
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            return {}
    return {}


@dataclass
class SimilarityResult:
    """One matched form returned by a vector-store lookup or search."""

    # source_id of the matched form.
    id: str
    # Original form data, decoded into a dictionary.
    form_data: Dict[str, Any]
    # Similarity score; a higher value means a closer match.
    score: float
    # Kind of form, e.g. "donor" or "volunteer".
    form_type: str
    # Raw L2 distance from the query vector (0.0 when not computed).
    distance: float = 0.0


class DonorVectorStore:
    """Vector storage and retrieval for donor/volunteer embeddings.

    Uses the existing my_embeddings table schema:
    - source_id: form ID
    - chunk_index: always 0 (single embedding per form)
    - text_content: JSON serialized form data
    - metadata: {"form_type": "donor"|"volunteer", ...}
    - embedding: VECTOR(1024)

    Attributes:
        pool: AsyncConnectionPool for database connections.
    """

    def __init__(self, pool):
        """Initialize vector store.

        Args:
            pool: AsyncConnectionPool from psycopg_pool
        """
        self.pool = pool

    @staticmethod
    def _escape_like(value: Any) -> str:
        """Escape LIKE/ILIKE metacharacters so ``value`` is matched literally.

        Backslash is PostgreSQL's default LIKE escape character, so it is
        doubled first; ``%`` and ``_`` are then prefixed with it.
        """
        return (
            str(value)
            .replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )

    @staticmethod
    def _row_to_result(
        row,
        default_form_type: str = "unknown",
        with_distance: bool = False,
    ) -> SimilarityResult:
        """Convert a (source_id, text_content, metadata[, distance]) row.

        Args:
            row: Database row; column 3 is read only if ``with_distance``.
            default_form_type: form_type used when metadata has none.
            with_distance: Whether the row carries a distance column.

        Returns:
            SimilarityResult with ``score = 1 / (1 + distance)`` for distance
            rows, or ``score=1.0`` / ``distance=0.0`` for plain lookups.
        """
        form_data = _parse_json_field(row[1])
        metadata = _parse_json_field(row[2])

        if with_distance:
            distance = float(row[3])
            score = 1.0 / (1.0 + distance)  # Convert distance to similarity
        else:
            distance = 0.0
            score = 1.0

        return SimilarityResult(
            id=row[0],
            form_data=form_data,
            score=score,
            form_type=metadata.get("form_type", default_form_type),
            distance=distance,
        )

    async def _fetch_all(self, query: str, params=None) -> list:
        """Run ``query`` on a pooled connection and return every row."""
        async with self.pool.connection() as conn:
            async with conn.cursor() as cur:
                await cur.execute(query, params)
                return await cur.fetchall()

    async def store_embedding(
        self,
        form_id: str,
        form_type: str,
        embedding: np.ndarray,
        form_data: Dict[str, Any]
    ) -> int:
        """Store form embedding in my_embeddings table.

        Args:
            form_id: Unique identifier for the form.
            form_type: Type of form ("donor" or "volunteer").
            embedding: The 1024-dimensional embedding vector.
            form_data: Original form data to store.

        Returns:
            The database ID of the inserted record.
        """
        embedding_list = embedding.tolist()
        # default=str stringifies values json cannot encode natively.
        form_json = json.dumps(form_data, default=str)

        async with self.pool.connection() as conn:
            async with conn.cursor() as cur:
                await cur.execute(
                    """
                    INSERT INTO my_embeddings
                    (source_id, chunk_index, text_content, metadata, embedding)
                    VALUES (%s, %s, %s, %s, %s::vector)
                    RETURNING id
                    """,
                    (
                        form_id,
                        0,  # Single embedding per form
                        form_json,
                        json.dumps({"form_type": form_type}),
                        embedding_list,
                    )
                )
                result = await cur.fetchone()
                return result[0]

    async def update_embedding(
        self,
        form_id: str,
        embedding: np.ndarray,
        form_data: Optional[Dict[str, Any]] = None
    ) -> bool:
        """Update an existing embedding.

        Args:
            form_id: The form ID to update.
            embedding: New embedding vector.
            form_data: Optional updated form data. None or an empty dict
                leaves the stored text_content untouched.

        Returns:
            True if update succeeded, False if record not found.
        """
        embedding_list = embedding.tolist()

        async with self.pool.connection() as conn:
            async with conn.cursor() as cur:
                if form_data:
                    form_json = json.dumps(form_data, default=str)
                    await cur.execute(
                        """
                        UPDATE my_embeddings
                        SET embedding = %s::vector, text_content = %s
                        WHERE source_id = %s
                        """,
                        (embedding_list, form_json, form_id)
                    )
                else:
                    await cur.execute(
                        """
                        UPDATE my_embeddings
                        SET embedding = %s::vector
                        WHERE source_id = %s
                        """,
                        (embedding_list, form_id)
                    )
                return cur.rowcount > 0

    async def delete_embedding(self, form_id: str) -> bool:
        """Delete an embedding by form ID.

        Args:
            form_id: The form ID to delete.

        Returns:
            True if deletion succeeded, False if record not found.
        """
        async with self.pool.connection() as conn:
            async with conn.cursor() as cur:
                await cur.execute(
                    "DELETE FROM my_embeddings WHERE source_id = %s",
                    (form_id,)
                )
                return cur.rowcount > 0

    async def get_embedding(self, form_id: str) -> Optional[SimilarityResult]:
        """Get a specific embedding by form ID.

        Args:
            form_id: The form ID to retrieve.

        Returns:
            SimilarityResult if found (with score=1.0 and distance=0.0, as no
            comparison is performed), None otherwise.
        """
        async with self.pool.connection() as conn:
            async with conn.cursor() as cur:
                await cur.execute(
                    """
                    SELECT source_id, text_content, metadata
                    FROM my_embeddings
                    WHERE source_id = %s
                    """,
                    (form_id,)
                )
                row = await cur.fetchone()

        if not row:
            return None
        return self._row_to_result(row)

    async def find_similar(
        self,
        query_embedding: np.ndarray,
        form_type: Optional[str] = None,
        limit: int = 10,
        country_filter: Optional[str] = None,
        exclude_ids: Optional[List[str]] = None
    ) -> List[SimilarityResult]:
        """Find similar donors/volunteers using vector similarity.

        Uses L2 distance (Euclidean) with IVFFlat index for efficient search.

        Args:
            query_embedding: The query embedding vector.
            form_type: Optional filter for "donor" or "volunteer".
            limit: Maximum number of results to return.
            country_filter: Optional filter for country code. Matched
                case-insensitively and literally against the serialized
                ``"country": "<value>"`` pair inside text_content.
            exclude_ids: Optional list of form IDs to exclude.

        Returns:
            List of SimilarityResult ordered by similarity (highest first).
        """
        embedding_list = query_embedding.tolist()

        # Build query with optional filters
        query = """
            SELECT
                source_id,
                text_content,
                metadata,
                embedding <-> %s::vector AS distance
            FROM my_embeddings
            WHERE 1=1
        """
        params: List[Any] = [embedding_list]

        if form_type:
            query += " AND metadata->>'form_type' = %s"
            params.append(form_type)

        if country_filter:
            query += " AND text_content ILIKE %s"
            # Escape the value so "%" / "_" in it are not treated as wildcards.
            escaped_country = self._escape_like(country_filter)
            params.append(f'%"country": "{escaped_country}"%')

        if exclude_ids:
            placeholders = ", ".join(["%s"] * len(exclude_ids))
            query += f" AND source_id NOT IN ({placeholders})"
            params.extend(exclude_ids)

        query += " ORDER BY distance ASC LIMIT %s"
        params.append(limit)

        rows = await self._fetch_all(query, params)
        return [self._row_to_result(row, with_distance=True) for row in rows]

    async def find_by_causes(
        self,
        target_causes: List[str],
        query_embedding: np.ndarray,
        limit: int = 20
    ) -> List[SimilarityResult]:
        """Hybrid search: filter by causes, rank by embedding similarity.

        Combines keyword filtering with vector similarity for better
        recommendations when specific causes are targeted.

        Args:
            target_causes: List of cause categories to match. Each one is
                matched as a literal, case-insensitive substring of
                text_content; a row qualifies if any cause matches.
            query_embedding: The query embedding for ranking.
            limit: Maximum number of results to return.

        Returns:
            List of SimilarityResult matching causes, ranked by similarity.
            Empty when ``target_causes`` is empty.
        """
        causes = list(target_causes)
        if not causes:
            # An empty OR-list would render as "WHERE ()", which is invalid
            # SQL; no causes means nothing can match the filter.
            return []

        embedding_list = query_embedding.tolist()

        # Build ILIKE clauses for cause filtering
        cause_conditions = " OR ".join(["text_content ILIKE %s"] * len(causes))
        cause_params = [f"%{self._escape_like(cause)}%" for cause in causes]

        query = f"""
            SELECT
                source_id,
                text_content,
                metadata,
                embedding <-> %s::vector AS distance
            FROM my_embeddings
            WHERE ({cause_conditions})
            ORDER BY distance ASC
            LIMIT %s
        """

        params = [embedding_list] + cause_params + [limit]

        rows = await self._fetch_all(query, params)
        return [self._row_to_result(row, with_distance=True) for row in rows]

    async def count_by_type(self) -> Dict[str, int]:
        """Get count of embeddings by form type.

        Returns:
            Dictionary {"donor": N, "volunteer": M, "total": T}, where T is
            the number of rows of every form type, including rows whose
            form_type is missing or is neither "donor" nor "volunteer".
        """
        rows = await self._fetch_all("""
                    SELECT
                        metadata->>'form_type' as form_type,
                        COUNT(*) as count
                    FROM my_embeddings
                    GROUP BY metadata->>'form_type'
                """)

        counts = {"donor": 0, "volunteer": 0, "total": 0}
        for row in rows:
            form_type = row[0]
            count = row[1]
            # Only the per-type keys may be assigned from data; testing
            # against the whole dict would let a form_type literally named
            # "total" overwrite the running total.
            if form_type in ("donor", "volunteer"):
                counts[form_type] = count
            counts["total"] += count

        return counts

    async def find_by_form_type(
        self, form_type: str, limit: int = 500
    ) -> List[SimilarityResult]:
        """Get all entries of a specific form type.

        Args:
            form_type: Type of form ("donor", "volunteer", or "client").
            limit: Maximum number of results to return.

        Returns:
            List of SimilarityResult for the specified form type, each with
            score=1.0 and distance=0.0 since no comparison is performed.
            The query has no ORDER BY, so which rows are returned when more
            than ``limit`` exist is up to the database.
        """
        query = """
            SELECT
                source_id,
                text_content,
                metadata
            FROM my_embeddings
            WHERE metadata->>'form_type' = %s
            LIMIT %s
        """

        rows = await self._fetch_all(query, (form_type, limit))
        return [
            self._row_to_result(row, default_form_type=form_type)
            for row in rows
        ]