Spaces:
Configuration error
Configuration error
import time
from typing import Optional, Dict, Any

from qdrant_client.http.models import Filter, FieldCondition, MatchAny, MatchValue
from qdrant_client.models import PointStruct

from core.clients import get_qdrant, get_s3
from core.config import QDRANT_COLLECTION, S3_BUCKET, AWS_REGION
from core.processing import embed_image_dino_large
from core.storage import download_image_from_s3
def _s3_key_from_url(image_url: str, url_prefix: str) -> str:
    """Return the S3 object key for *image_url* by stripping a leading *url_prefix*.

    URLs that do not start with the prefix are returned unchanged, so the
    subsequent download fails and is counted instead of being silently altered.
    """
    if image_url.startswith(url_prefix):
        return image_url[len(url_prefix):]
    return image_url


def _migrate_point(
    client: Any,
    point: Any,
    image_url: str,
    url_prefix: str,
    verbose: bool,
) -> str:
    """Download, embed and upsert one point; return the stats counter to increment."""
    point_id = point.id
    s3_key = _s3_key_from_url(image_url, url_prefix)

    # Download image from S3
    try:
        image_np = download_image_from_s3(s3_key)
    except Exception as e:
        if verbose:
            print(f" Failed to download {s3_key}: {e}")
        return "failed_downloads"

    # Generate DINOv2 embedding
    try:
        dinov2_embedding = embed_image_dino_large(image_np)
    except Exception as e:
        if verbose:
            print(f" Failed to generate embedding for {point_id}: {e}")
        return "failed_embeddings"

    # Re-upsert the point with its existing vectors plus the new one
    try:
        updated_vectors = dict(point.vector)
        updated_vectors["dinov2_embedding"] = dinov2_embedding.tolist()
        client.upsert(
            collection_name=QDRANT_COLLECTION,
            points=[
                PointStruct(
                    id=point_id,
                    vector=updated_vectors,
                    payload=point.payload,
                )
            ],
        )
    except Exception as e:
        if verbose:
            print(f" Failed to update point {point_id}: {e}")
        return "failed_updates"

    if verbose:
        print(f" Successfully migrated point {point_id}")
    return "successful_migrations"


def migrate_clip_to_dinov2(
    batch_size: int = 50,
    filter_condition: Optional[Filter] = None,
    dry_run: bool = False,
    verbose: bool = True
) -> Dict[str, Any]:
    """
    Migrate existing CLIP embeddings to include DINOv2 embeddings.

    Scrolls through the collection in batches; for every point that has an
    image URL and no "dinov2_embedding" vector yet, downloads the image,
    computes the embedding and upserts the point with the extra named vector.
    Per-point failures are counted and never abort the run.

    Args:
        batch_size: Number of points to process in each batch
        filter_condition: Optional filter to limit which points to migrate
            (defaults to points whose payload "type" is "image")
        dry_run: If True, only count points without making changes
        verbose: If True, print progress information

    Returns:
        Dictionary with migration statistics
    """
    client = get_qdrant()
    stats = {
        "total_processed": 0,
        "successful_migrations": 0,
        "failed_downloads": 0,
        "failed_embeddings": 0,
        "failed_updates": 0,
        "skipped_no_image": 0,
        "skipped_has_dinov2": 0
    }

    # Default filter to only process image points
    if filter_condition is None:
        filter_condition = Filter(
            must=[
                FieldCondition(key="type", match=MatchValue(value="image"))
            ]
        )

    # Loop-invariant: the public URL prefix that precedes every S3 key.
    url_prefix = f"https://{S3_BUCKET}.s3.{AWS_REGION}.amazonaws.com/"

    offset = None
    batch_count = 0
    while True:
        # Scroll through points with vectors included
        points, offset = client.scroll(
            collection_name=QDRANT_COLLECTION,
            scroll_filter=filter_condition,
            limit=batch_size,
            offset=offset,
            with_vectors=True,
            with_payload=True
        )
        if not points:
            break

        batch_count += 1
        if verbose:
            print(f"Processing batch {batch_count} ({len(points)} points)...")

        for point in points:
            stats["total_processed"] += 1
            point_id = point.id
            # A point matched by a caller-supplied filter may carry no payload.
            payload = point.payload or {}

            # Skip if no image URL
            image_url = payload.get("image_url") or payload.get("imageUri")
            if not image_url:
                stats["skipped_no_image"] += 1
                if verbose:
                    print(f" Skipping point {point_id}: no image URL")
                continue

            vectors = point.vector
            if vectors is None:
                stats["failed_updates"] += 1
                if verbose:
                    print(f" Skipping point {point_id}: no vectors returned")
                continue

            # An unnamed (list) vector cannot take an extra named vector; fail
            # it up front rather than after a wasted download and embedding.
            if not isinstance(vectors, dict):
                stats["failed_updates"] += 1
                if verbose:
                    print(f" Skipping point {point_id}: point does not use named vectors")
                continue

            # Check if DINOv2 embedding already exists
            if "dinov2_embedding" in vectors:
                stats["skipped_has_dinov2"] += 1
                if verbose:
                    print(f" Skipping point {point_id}: already has DINOv2 embedding")
                continue

            if dry_run:
                stats["successful_migrations"] += 1
                continue

            try:
                outcome = _migrate_point(client, point, image_url, url_prefix, verbose)
            except Exception as e:
                # Anything the helper does not classify itself (e.g. a
                # non-string image URL) still counts as a failed update.
                outcome = "failed_updates"
                if verbose:
                    print(f" Unexpected error processing point {point_id}: {e}")
            stats[outcome] += 1

        # Brief pause between batches to avoid overwhelming the system
        if not dry_run:
            time.sleep(0.1)

        # Break if no more points
        if offset is None:
            break

    return stats
def migrate_specific_objects(object_ids: list, dry_run: bool = False, verbose: bool = True) -> Dict[str, Any]:
    """
    Migrate specific objects by their object_id.

    Args:
        object_ids: List of object IDs to migrate (a single str/int ID is
            also accepted)
        dry_run: If True, only count points without making changes
        verbose: If True, print progress information

    Returns:
        Dictionary with migration statistics
    """
    # MatchValue takes exactly one scalar; a collection of IDs needs MatchAny,
    # which matches points whose "object_id" equals any of the listed values.
    if isinstance(object_ids, (str, int)):
        id_match = MatchValue(value=object_ids)
    else:
        id_match = MatchAny(any=list(object_ids))

    filter_condition = Filter(
        must=[
            FieldCondition(key="type", match=MatchValue(value="image")),
            FieldCondition(key="object_id", match=id_match),
        ]
    )
    return migrate_clip_to_dinov2(
        filter_condition=filter_condition,
        dry_run=dry_run,
        verbose=verbose
    )
def get_migration_status() -> Dict[str, float]:
    """
    Get current migration status - count of points with and without DINOv2 embeddings.

    Scrolls through every point whose payload "type" is "image" and checks
    whether its named vectors include "dinov2_embedding".

    Returns:
        Dictionary with the keys "total_image_points", "with_dinov2",
        "without_dinov2" (integer counts) and "migration_progress"
        (percentage of image points that have the embedding; 0.0 when there
        are no image points)
    """
    client = get_qdrant()

    # Built once: the same filter is reused for every page.
    image_filter = Filter(
        must=[FieldCondition(key="type", match=MatchValue(value="image"))]
    )

    total_images = 0
    has_dinov2 = 0
    offset = None
    while True:
        points, offset = client.scroll(
            collection_name=QDRANT_COLLECTION,
            scroll_filter=image_filter,
            limit=100,
            offset=offset,
            with_payload=False,
            with_vectors=True
        )
        if not points:
            break

        total_images += len(points)
        # Points without vectors (None) or with a single unnamed vector (list)
        # cannot hold the named embedding and count as "without".
        has_dinov2 += sum(
            1
            for point in points
            if isinstance(point.vector, dict) and "dinov2_embedding" in point.vector
        )

        if offset is None:
            break

    return {
        "total_image_points": total_images,
        "with_dinov2": has_dinov2,
        "without_dinov2": total_images - has_dinov2,
        "migration_progress": (has_dinov2 / total_images * 100) if total_images > 0 else 0.0
    }
| # Example usage functions | |
def run_full_migration():
    """Migrate every image point, printing the status before and after the run."""
    print("Starting full migration...")

    # Report where the collection stands before touching anything.
    initial_status = get_migration_status()
    print(f"Migration status: {initial_status}")

    pending = initial_status["without_dinov2"]
    if pending == 0:
        print("All points already have DINOv2 embeddings!")
        return

    # Perform the migration and report its counters.
    migration_stats = migrate_clip_to_dinov2(verbose=True)
    print(f"\nMigration completed!")
    print(f"Statistics: {migration_stats}")

    # Re-query so the final report reflects what is actually stored.
    print(f"Final status: {get_migration_status()}")
def run_dry_run():
    """Report what a migration would do without changing any points.

    Returns:
        The statistics dictionary produced by the dry-run migration.
    """
    print("Running dry run...")
    dry_run_stats = migrate_clip_to_dinov2(verbose=True, dry_run=True)
    print(f"Dry run results: {dry_run_stats}")
    return dry_run_stats
if __name__ == "__main__":
    # Script entry point: run the complete migration.
    # To preview the work without writing anything, call run_dry_run() instead.
    run_full_migration()