File size: 21,770 Bytes
ba2fc46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
# import asyncio
# import json
# import requests
# import concurrent.futures
# from uuid import uuid4
# from sqlalchemy.future import select
# from qdrant_client.http import models

# # Internal Modules
# from backend.src.models.integration import UserIntegration
# from backend.src.models.ingestion import IngestionJob, JobStatus
# from backend.src.services.vector_store.qdrant_adapter import get_vector_store
# from backend.src.services.visual.engine import get_image_embedding

# # Connectors
# from backend.src.services.connectors.sanity_connector import SanityConnector
# from backend.src.services.connectors.shopify_connector import fetch_all_products as fetch_shopify
# from backend.src.services.connectors.woocommerce_connector import fetch_all_products as fetch_woo

# # --- OPTIMIZATION CONFIG ---
# BATCH_SIZE = 100 
# MAX_WORKERS = 20

# # --- πŸ”₯ SAFE LOGGING HELPER ---
# async def update_job_safe(db_factory, job_id: int, status: str, processed=0, total=0, error=None):
#     try:
#         async with db_factory() as db:
#             result = await db.execute(select(IngestionJob).where(IngestionJob.id == job_id))
#             job = result.scalars().first()
#             if job:
#                 job.status = status
#                 job.items_processed = processed
#                 job.total_items = total
#                 if error:
#                     job.error_message = str(error)
#                 await db.commit()
#     except Exception as e:
#         print(f"⚠️ Status Update Failed: {e}")

# async def fetch_products_from_source(provider: str, credentials: dict):
#     products = []
#     print(f"πŸ”„ [Visual Agent] Fetching products from {provider}...")
#     try:
#         if provider == 'sanity':
#             connector = SanityConnector(credentials)
#             query = """*[_type == "product" && defined(variants)]{
#                 _id, "slug": slug.current, "variants": variants[]{ _key, images[]{ asset->{url} } }
#             }"""
#             raw_data = connector.execute_query(query)
#             for item in raw_data:
#                 if not item.get('variants'): continue
#                 for variant in item['variants']:
#                     if not variant.get('images'): continue
#                     for img in variant['images']:
#                         if img.get('asset'):
#                             products.append({
#                                 "id": f"{item['_id']}_{variant['_key']}",
#                                 "image_path": img['asset']['url'],
#                                 "slug": item.get('slug'),
#                                 "product_id": item['_id']
#                             })
#         elif provider == 'shopify':
#             products = await asyncio.to_thread(fetch_shopify, credentials)
#         elif provider == 'woocommerce':
#             products = await asyncio.to_thread(fetch_woo, credentials)
#         return products
#     except Exception as e:
#         print(f"❌ Fetch Error: {e}")
#         return []

# def download_and_vectorize(product):
#     # Ensure we use the correct key for image path
#     image_url = product.get('image_path') or product.get('image_url')
    
#     if not image_url:
#         return None

#     try:
#         response = requests.get(image_url, timeout=5)
#         if response.status_code != 200: return None
#         image_bytes = response.content
#         vector = get_image_embedding(image_bytes)
#         if not vector: return None
#         return {"product": product, "vector": vector}
#     except Exception:
#         return None

# async def run_visual_sync(user_id: str, job_id: int, db_factory):
#     """
#     High Performance Sync: Uses ThreadPool for parallel processing.
#     """
#     print(f"πŸš€ [Visual Agent] Starting Optimized Sync Job {job_id} for User: {user_id}")
    
#     try:
#         await update_job_safe(db_factory, job_id, JobStatus.PROCESSING)

#         # 1. Credentials Fetch
#         async with db_factory() as db:
#             stmt = select(UserIntegration).where(
#                 UserIntegration.user_id == str(user_id),
#                 UserIntegration.is_active == True
#             )
#             result = await db.execute(stmt)
#             integrations = result.scalars().all()

#         qdrant_config = None
#         store_config = None
#         store_provider = None

#         for i in integrations:
#             if i.provider == 'qdrant':
#                 qdrant_config = json.loads(i.credentials)
#             elif i.provider in ['sanity', 'shopify', 'woocommerce']:
#                 store_config = json.loads(i.credentials)
#                 store_provider = i.provider

#         if not qdrant_config or not store_config:
#             await update_job_safe(db_factory, job_id, JobStatus.FAILED, error="Missing Database or Store connection.")
#             return

#         # 2. Connect Qdrant & Setup Collection
#         vector_store = get_vector_store(credentials=qdrant_config)
#         collection_name = "visual_search_products"
        
#         # Reset Collection
#         try:
#             vector_store.client.delete_collection(collection_name)
#         except: pass
        
#         vector_store.client.create_collection(
#             collection_name=collection_name,
#             vectors_config=models.VectorParams(size=2048, distance=models.Distance.COSINE)
#         )

#         # βœ… FIXED: Create Payload Index for 'user_id'
#         # Ye zaroori hai taake Qdrant filter query allow kare
#         print(f"πŸ› οΈ [Visual Agent] Creating index for user_id on {collection_name}...")
#         vector_store.client.create_payload_index(
#             collection_name=collection_name,
#             field_name="user_id",
#             field_schema=models.PayloadSchemaType.KEYWORD
#         )

#         # 3. Fetch Products
#         products = await fetch_products_from_source(store_provider, store_config)
#         total_products = len(products)
#         await update_job_safe(db_factory, job_id, JobStatus.PROCESSING, total=total_products)

#         if not products:
#             await update_job_safe(db_factory, job_id, JobStatus.COMPLETED, error="No products found.")
#             return

#         print(f"⚑ Processing {total_products} images in batches of {BATCH_SIZE}...")

#         # 4. OPTIMIZED BATCH PROCESSING
#         processed_count = 0
#         loop = asyncio.get_running_loop()
        
#         for i in range(0, total_products, BATCH_SIZE):
#             batch = products[i : i + BATCH_SIZE]
#             points = []

#             with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
#                 futures = [
#                     loop.run_in_executor(executor, download_and_vectorize, item) 
#                     for item in batch
#                 ]
#                 results = await asyncio.gather(*futures)

#             for res in results:
#                 if res:
#                     prod = res['product']
#                     # Use .get() to avoid KeyError if keys differ across providers
#                     img_url = prod.get('image_path') or prod.get('image_url')
                    
#                     points.append(models.PointStruct(
#                         id=str(uuid4()), 
#                         vector=res['vector'],
#                         payload={
#                             "product_id": prod.get('product_id'),
#                             "slug": prod.get('slug'),
#                             "image_url": img_url,
#                             "user_id": str(user_id),
#                             "source": store_provider
#                         }
#                     ))

#             if points:
#                 await asyncio.to_thread(
#                     vector_store.client.upsert,
#                     collection_name=collection_name, 
#                     points=points
#                 )
#                 processed_count += len(points)

#             # --- SAFE STATUS UPDATE ---
#             await update_job_safe(db_factory, job_id, JobStatus.PROCESSING, processed=processed_count, total=total_products)
#             print(f"   -> Batch {i//BATCH_SIZE + 1} done. ({processed_count}/{total_products})")

#         # Final Success
#         await update_job_safe(db_factory, job_id, JobStatus.COMPLETED, processed=processed_count, total=total_products)
#         print(f"πŸŽ‰ Job {job_id} Complete. {processed_count} images indexed.")

#     except Exception as e:
#         print(f"❌ Job {job_id} Failed: {e}")
#         await update_job_safe(db_factory, job_id, JobStatus.FAILED, error=str(e))
import asyncio
import json
import requests
import concurrent.futures
from uuid import uuid4
from sqlalchemy.future import select
from qdrant_client.http import models

# Internal Modules
from backend.src.models.integration import UserIntegration
from backend.src.models.ingestion import IngestionJob, JobStatus
from backend.src.services.vector_store.qdrant_adapter import get_vector_store
from backend.src.services.visual.engine import get_image_embedding

# Connectors
from backend.src.services.connectors.sanity_connector import SanityConnector
from backend.src.services.connectors.shopify_connector import fetch_all_products as fetch_shopify
from backend.src.services.connectors.woocommerce_connector import fetch_all_products as fetch_woo

# --- OPTIMIZATION CONFIG ---
BATCH_SIZE = 100 
MAX_WORKERS = 20

# --- πŸ”₯ SAFE LOGGING HELPER ---
async def update_job_safe(db_factory, job_id: int, status: str, processed=0, total=0, error=None, message=None):
    """Best-effort update of an IngestionJob row's status and progress.

    Any exception raised while persisting the status is swallowed (and
    printed) so that a broken progress update can never kill the sync job
    that called it.
    """
    try:
        async with db_factory() as db:
            query = select(IngestionJob).where(IngestionJob.id == job_id)
            job = (await db.execute(query)).scalars().first()
            if job is None:
                return
            job.status = status
            job.items_processed = processed
            job.total_items = total
            if error:
                job.error_message = str(error)
            # Optional free-form log line; echoed only, not persisted on the model.
            if message:
                print(f"πŸ“ Job Log: {message}")
            await db.commit()
    except Exception as exc:
        print(f"⚠️ Status Update Failed: {exc}")

async def fetch_products_from_source(provider: str, credentials: dict):
    """Fetch a flat list of product-image records from the connected store.

    Each record is a dict with keys: id, image_path, slug, product_id
    (Sanity), or whatever the Shopify/WooCommerce connectors emit.
    Returns [] on any fetch error (best-effort, error is printed).
    """
    products = []
    print(f"πŸ”„ [Visual Agent] Fetching products from {provider}...")
    try:
        if provider == 'sanity':
            connector = SanityConnector(credentials)
            query = """*[_type == "product" && defined(variants)]{
                _id, "slug": slug.current, "variants": variants[]{ _key, images[]{ asset->{url} } }
            }"""
            # Flatten product -> variant -> image into one record per image.
            for doc in connector.execute_query(query):
                for variant in doc.get('variants') or []:
                    for img in variant.get('images') or []:
                        asset = img.get('asset')
                        if not asset:
                            continue
                        products.append({
                            "id": f"{doc['_id']}_{variant['_key']}",
                            "image_path": asset['url'],
                            "slug": doc.get('slug'),
                            "product_id": doc['_id']
                        })
        elif provider == 'shopify':
            # Connectors are sync; push them off the event loop.
            products = await asyncio.to_thread(fetch_shopify, credentials)
        elif provider == 'woocommerce':
            products = await asyncio.to_thread(fetch_woo, credentials)
        return products
    except Exception as e:
        print(f"❌ Fetch Error: {e}")
        return []

def download_and_vectorize(product):
    """Download one product image and embed it.

    Runs inside a worker thread (see run_visual_sync). Returns
    {"product": ..., "vector": ...} on success, or None on any failure
    (missing URL, non-200 response, embedding failure, network error).
    """
    # Providers disagree on the key name for the image location.
    url = product.get('image_path') or product.get('image_url')
    if not url:
        return None

    try:
        resp = requests.get(url, timeout=5)
        if resp.status_code != 200:
            return None
        vector = get_image_embedding(resp.content)
        return {"product": product, "vector": vector} if vector else None
    except Exception:
        # Best-effort: a single bad image must not abort the batch.
        return None

# --- 🧠 SMART DIFF HELPER ---
async def get_current_qdrant_state(client, collection_name, user_id):
    """
    Fetch only point IDs and image URLs from Qdrant so the sync can diff
    against the store without downloading vectors.

    Returns: Dict { 'product_unique_id::image_url': 'qdrant_uuid' }
    """
    state = {}
    offset = None

    print(f"πŸ•΅οΈ Scanning existing Qdrant data for User: {user_id} in '{collection_name}'...")

    # Loop-invariant: restrict the scan to this user's points.
    user_filter = models.Filter(
        must=[models.FieldCondition(key="user_id", match=models.MatchValue(value=str(user_id)))]
    )

    while True:
        # Paginate via scroll; skip vectors entirely (payload only) since
        # downloading them would be slow and they are not needed for the diff.
        records, offset = await asyncio.to_thread(
            client.scroll,
            collection_name=collection_name,
            scroll_filter=user_filter,
            limit=1000,
            with_payload=True,
            with_vectors=False,
            offset=offset
        )

        for record in records:
            payload = record.payload or {}
            prod_id = payload.get("product_id")
            img_url = payload.get("image_url")
            if prod_id and img_url:
                # Composite key uniquely identifies a (product, image) pair;
                # the stored Qdrant UUID is what we need later for deletes.
                state[f"{prod_id}::{img_url}"] = record.id

        # Qdrant signals the end of the scroll with a None offset.
        if offset is None:
            break

    print(f"βœ… Found {len(state)} existing records in DB.")
    return state

async def run_visual_sync(user_id: str, job_id: int, db_factory):
    """
    πŸš€ Smart Incremental Sync.

    1. Fetch the fresh product list from the connected store.
    2. Fetch the current index state from Qdrant.
    3. Calculate the diff (add new images, delete removed ones, skip unchanged).
    4. Execute the updates in parallel batches, reporting progress through
       update_job_safe.

    Fixes vs previous revision:
    - Final status now reports the ACTUAL number of indexed images
      (processed_count) instead of the planned count, so images whose
      download/embedding failed are no longer reported as "Added".
    - client.delete is wrapped in asyncio.to_thread so the blocking Qdrant
      call cannot stall the event loop (consistent with upsert/scroll).
    """
    print(f"πŸš€ [Visual Agent] Starting Smart Sync Job {job_id} for User: {user_id}")

    try:
        await update_job_safe(db_factory, job_id, JobStatus.PROCESSING)

        # 1. Credentials Fetch
        async with db_factory() as db:
            stmt = select(UserIntegration).where(
                UserIntegration.user_id == str(user_id),
                UserIntegration.is_active == True
            )
            result = await db.execute(stmt)
            integrations = result.scalars().all()

        qdrant_config = None
        store_config = None
        store_provider = None

        for i in integrations:
            if i.provider == 'qdrant':
                qdrant_config = json.loads(i.credentials)
            elif i.provider in ['sanity', 'shopify', 'woocommerce']:
                store_config = json.loads(i.credentials)
                store_provider = i.provider

        if not qdrant_config or not store_config:
            await update_job_safe(db_factory, job_id, JobStatus.FAILED, error="Missing Database or Store connection.")
            return

        # 2. Connect Qdrant & Check Collection
        vector_store = get_vector_store(credentials=qdrant_config)

        # Explicitly look for 'visual_collection_name'; default to
        # 'visual_search_products' if the user did not set one. We deliberately
        # do NOT reuse 'collection_name' (reserved for chat) so the two
        # indexes never mix.
        collection_name = qdrant_config.get("visual_collection_name", "visual_search_products")

        client = vector_store.client

        # Ensure the collection exists. Vector size 2048 must match the
        # output of get_image_embedding.
        if not client.collection_exists(collection_name):
            print(f"πŸ› οΈ Creating new collection: {collection_name}")
            client.create_collection(
                collection_name=collection_name,
                vectors_config=models.VectorParams(size=2048, distance=models.Distance.COSINE)
            )
            # Keyword index on user_id is required for filtered scroll/search.
            client.create_payload_index(
                collection_name=collection_name,
                field_name="user_id",
                field_schema=models.PayloadSchemaType.KEYWORD
            )

        # 3. Fetch Data from Source (Fresh List)
        source_products = await fetch_products_from_source(store_provider, store_config)
        if not source_products:
            await update_job_safe(db_factory, job_id, JobStatus.COMPLETED, error="No products found in store.")
            return

        # 4. Fetch Data from Qdrant (Existing List)
        # Map: "ProductID::ImageURL" -> QdrantUUID
        db_state = await get_current_qdrant_state(client, collection_name, user_id)

        # 5. 🧠 CALCULATE THE DIFF
        points_to_delete = []
        items_to_process = []

        # A. Identify New Items & Unchanged Items
        source_keys = set()

        for prod in source_products:
            prod_id = prod.get('product_id')
            img_url = prod.get('image_path') or prod.get('image_url')

            if not prod_id or not img_url: continue

            key = f"{prod_id}::{img_url}"
            source_keys.add(key)

            if key in db_state:
                # Already indexed with the exact same image URL -> SKIP (save time)
                continue
            else:
                # New item (or URL changed, which yields a new key) -> PROCESS
                items_to_process.append(prod)

        # B. Identify Deleted Items: anything present in the DB but absent
        # from the store (source_keys) must be removed.
        for db_key, db_uuid in db_state.items():
            if db_key not in source_keys:
                points_to_delete.append(db_uuid)

        # Stats. NOTE: 'unchanged' also absorbs items skipped above for a
        # missing product_id/image URL.
        total_source = len(source_products)
        to_add_count = len(items_to_process)
        to_delete_count = len(points_to_delete)
        unchanged_count = total_source - to_add_count

        print(f"πŸ“Š Sync Analysis for User {user_id}:")
        print(f"   - Collection: {collection_name}")
        print(f"   - Total in Store: {total_source}")
        print(f"   - Unchanged (Skipping): {unchanged_count}")
        print(f"   - To Add/Update: {to_add_count}")
        print(f"   - To Delete (Removed from Store): {to_delete_count}")

        # 6. EXECUTE DELETE (if anything was removed from the store)
        if points_to_delete:
            print(f"πŸ—‘οΈ Deleting {to_delete_count} obsolete records...")
            # Delete by point ID (UUID), batched to keep request sizes sane.
            # Run in a worker thread so the blocking client call does not
            # stall the event loop (consistent with upsert/scroll).
            chunk_size = 1000
            for i in range(0, len(points_to_delete), chunk_size):
                chunk = points_to_delete[i:i + chunk_size]
                await asyncio.to_thread(
                    client.delete,
                    collection_name=collection_name,
                    points_selector=models.PointIdsList(points=chunk)
                )

        # 7. EXECUTE ADD/UPDATE (Batch Processing)
        # Count of images actually embedded & upserted (failed downloads excluded).
        processed_count = 0
        if items_to_process:
            print(f"⚑ Processing {to_add_count} new images...")

            # Initial status update
            await update_job_safe(db_factory, job_id, JobStatus.PROCESSING, total=to_add_count, processed=0)

            loop = asyncio.get_running_loop()

            for i in range(0, len(items_to_process), BATCH_SIZE):
                batch = items_to_process[i : i + BATCH_SIZE]
                points = []

                # Parallel Download & Vectorize
                with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
                    futures = [
                        loop.run_in_executor(executor, download_and_vectorize, item)
                        for item in batch
                    ]
                    results = await asyncio.gather(*futures)

                for res in results:
                    if res:
                        prod = res['product']
                        # Use .get() to avoid KeyError if keys differ across providers
                        img_url = prod.get('image_path') or prod.get('image_url')

                        points.append(models.PointStruct(
                            id=str(uuid4()),
                            vector=res['vector'],
                            payload={
                                "product_id": prod.get('product_id'),
                                "slug": prod.get('slug'),
                                "image_url": img_url,
                                "user_id": str(user_id),
                                "source": store_provider
                            }
                        ))

                if points:
                    await asyncio.to_thread(
                        client.upsert,
                        collection_name=collection_name,
                        points=points
                    )
                    processed_count += len(points)

                # Progress Update
                await update_job_safe(db_factory, job_id, JobStatus.PROCESSING, processed=processed_count, total=to_add_count)
                print(f"   -> Batch {i//BATCH_SIZE + 1} done. ({processed_count}/{to_add_count})")
        else:
            print("✨ No new images to process.")

        # Final Success — report what actually got indexed.
        final_msg = f"Sync Complete. Added: {processed_count}, Deleted: {to_delete_count}, Skipped: {unchanged_count}"
        await update_job_safe(db_factory, job_id, JobStatus.COMPLETED, processed=processed_count, total=to_add_count, message=final_msg)
        print(f"πŸŽ‰ {final_msg}")

    except Exception as e:
        print(f"❌ Job {job_id} Failed: {e}")
        await update_job_safe(db_factory, job_id, JobStatus.FAILED, error=str(e))