# object-memory / qdrant_utils / qdrant_migrate_to_dino.py
# russ4stall
# fresh history
# 24f3fb6
import time
from typing import Optional, Dict, Any

from qdrant_client.http.models import Filter, FieldCondition, MatchAny, MatchValue
from qdrant_client.models import PointStruct

from core.clients import get_qdrant, get_s3
from core.config import QDRANT_COLLECTION, S3_BUCKET, AWS_REGION
from core.processing import embed_image_dino_large
from core.storage import download_image_from_s3
def _migrate_point(
    client: Any,
    point: Any,
    s3_url_prefix: str,
    dry_run: bool,
    verbose: bool,
) -> str:
    """Process one scrolled point and return the name of the stats counter to bump.

    Returns one of: "skipped_no_image", "skipped_no_vectors",
    "skipped_has_dinov2", "failed_downloads", "failed_embeddings",
    "failed_updates" or "successful_migrations".
    """
    point_id = point.id
    payload = point.payload or {}

    # Skip if no image URL (two payload spellings are supported).
    image_url = payload.get("image_url") or payload.get("imageUri")
    if not image_url:
        if verbose:
            print(f"  Skipping point {point_id}: no image URL")
        return "skipped_no_image"

    vectors = point.vector
    if vectors is None:
        # Nothing was returned for this point, so there is nothing to extend.
        # This is a skip, not a failed update.
        if verbose:
            print(f"  Skipping point {point_id}: no vectors returned")
        return "skipped_no_vectors"
    if not isinstance(vectors, dict):
        # An unnamed single vector (a plain list) cannot receive an extra
        # named vector. Detect it here instead of downloading the image and
        # running the model only for the update to fail afterwards.
        if verbose:
            print(
                f"  Failed to update point {point_id}: point has an unnamed "
                "vector; named vectors are required"
            )
        return "failed_updates"
    if "dinov2_embedding" in vectors:
        if verbose:
            print(f"  Skipping point {point_id}: already has DINOv2 embedding")
        return "skipped_has_dinov2"

    if dry_run:
        # Dry run: count what *would* be migrated without touching S3 or Qdrant.
        return "successful_migrations"

    if not isinstance(image_url, str):
        if verbose:
            print(
                f"  Unexpected error processing point {point_id}: "
                f"image URL is not a string: {image_url!r}"
            )
        return "failed_updates"

    # Strip the bucket URL prefix to obtain the object key. Values without
    # that prefix are passed through unchanged.
    if image_url.startswith(s3_url_prefix):
        s3_key = image_url[len(s3_url_prefix):]
    else:
        s3_key = image_url

    try:
        image_np = download_image_from_s3(s3_key)
    except Exception as e:
        if verbose:
            print(f"  Failed to download {s3_key}: {e}")
        return "failed_downloads"

    try:
        dinov2_embedding = embed_image_dino_large(image_np)
    except Exception as e:
        if verbose:
            print(f"  Failed to generate embedding for {point_id}: {e}")
        return "failed_embeddings"

    try:
        # Keep every existing named vector and add the new one.
        # NOTE(review): assumes embed_image_dino_large returns an array-like
        # with .tolist() (e.g. numpy) — confirm.
        updated_vectors = dict(vectors)
        updated_vectors["dinov2_embedding"] = dinov2_embedding.tolist()
        # upsert writes the whole point: all vectors plus the payload read
        # during the scroll.
        client.upsert(
            collection_name=QDRANT_COLLECTION,
            points=[PointStruct(id=point_id, vector=updated_vectors, payload=payload)],
        )
    except Exception as e:
        if verbose:
            print(f"  Failed to update point {point_id}: {e}")
        return "failed_updates"

    if verbose:
        print(f"  Successfully migrated point {point_id}")
    return "successful_migrations"


def migrate_clip_to_dinov2(
    batch_size: int = 50,
    filter_condition: Optional[Filter] = None,
    dry_run: bool = False,
    verbose: bool = True
) -> Dict[str, Any]:
    """
    Migrate existing CLIP embeddings to include DINOv2 embeddings.

    Scrolls ``QDRANT_COLLECTION`` page by page. For every matching point that
    has an image URL, named vectors, and no ``dinov2_embedding`` vector yet,
    it does the following:
      1. Downloads the image from S3.
      2. Computes a DINOv2 embedding.
      3. Upserts the point with the new named vector added.

    Args:
        batch_size: Number of points to fetch per scroll page.
        filter_condition: Optional filter to limit which points to migrate.
            Defaults to points whose payload ``type`` is ``"image"``.
        dry_run: If True, only count points that would be migrated, without
            making changes.
        verbose: If True, print progress information.

    Returns:
        Dictionary with migration statistics. The keys are:
          - ``total_processed``
          - ``successful_migrations``
          - ``failed_downloads``
          - ``failed_embeddings``
          - ``failed_updates`` (also counts points with an unnamed vector)
          - ``skipped_no_image``
          - ``skipped_has_dinov2``
          - ``skipped_no_vectors``
    """
    client = get_qdrant()
    stats = {
        "total_processed": 0,
        "successful_migrations": 0,
        "failed_downloads": 0,
        "failed_embeddings": 0,
        "failed_updates": 0,
        "skipped_no_image": 0,
        "skipped_has_dinov2": 0,
        "skipped_no_vectors": 0,
    }

    # Default filter to only process image points.
    if filter_condition is None:
        filter_condition = Filter(
            must=[FieldCondition(key="type", match=MatchValue(value="image"))]
        )

    # Loop-invariant: the URL prefix that is stripped to obtain the S3 key.
    s3_url_prefix = f"https://{S3_BUCKET}.s3.{AWS_REGION}.amazonaws.com/"

    offset = None
    batch_count = 0
    while True:
        # Vectors are needed to detect an existing DINOv2 vector and to
        # write back the existing ones.
        points, offset = client.scroll(
            collection_name=QDRANT_COLLECTION,
            scroll_filter=filter_condition,
            limit=batch_size,
            offset=offset,
            with_vectors=True,
            with_payload=True,
        )
        if not points:
            break

        batch_count += 1
        if verbose:
            print(f"Processing batch {batch_count} ({len(points)} points)...")

        for point in points:
            stats["total_processed"] += 1
            outcome = _migrate_point(client, point, s3_url_prefix, dry_run, verbose)
            stats[outcome] += 1

        # Brief pause between batches to avoid overwhelming the system.
        if not dry_run:
            time.sleep(0.1)

        # A None offset means the scroll has no further pages.
        if offset is None:
            break

    return stats
def migrate_specific_objects(
    object_ids: list,
    dry_run: bool = False,
    verbose: bool = True,
    batch_size: int = 50,
) -> Dict[str, Any]:
    """
    Migrate the image points whose payload ``object_id`` is one of ``object_ids``.

    Args:
        object_ids: Object IDs to migrate. This is usually a list, but any
            iterable works, and a single str/int ID is also accepted.
        dry_run: If True, only count points without making changes.
        verbose: If True, print progress information.
        batch_size: Number of points fetched per scroll page.

    Returns:
        Dictionary with migration statistics, as returned by
        ``migrate_clip_to_dinov2``.

    Raises:
        ValueError: If no object IDs are given.
    """
    # Accept a lone scalar ID as well as any iterable of IDs.
    if isinstance(object_ids, (str, int)):
        ids = [object_ids]
    else:
        ids = list(object_ids)
    if not ids:
        raise ValueError("object_ids must contain at least one object ID")

    filter_condition = Filter(
        must=[
            FieldCondition(key="type", match=MatchValue(value="image")),
            # MatchValue takes a single scalar. MatchAny matches when the
            # payload value equals any element of the list.
            FieldCondition(key="object_id", match=MatchAny(any=ids)),
        ]
    )
    return migrate_clip_to_dinov2(
        batch_size=batch_size,
        filter_condition=filter_condition,
        dry_run=dry_run,
        verbose=verbose,
    )
def get_migration_status() -> Dict[str, float]:
    """
    Get current migration status: counts of image points with and without
    the ``dinov2_embedding`` named vector.

    Scrolls every point in ``QDRANT_COLLECTION`` whose payload ``type`` is
    ``"image"`` (100 points per page) and checks its named vectors.

    Returns:
        Dictionary with these keys:
          - ``total_image_points``
          - ``with_dinov2``
          - ``without_dinov2``
          - ``migration_progress``: percentage of image points that have
            the DINOv2 vector, or 0 when there are no image points.
    """
    client = get_qdrant()
    image_filter = Filter(
        must=[FieldCondition(key="type", match=MatchValue(value="image"))]
    )

    total_images = 0
    has_dinov2 = 0
    offset = None
    while True:
        points, offset = client.scroll(
            collection_name=QDRANT_COLLECTION,
            scroll_filter=image_filter,
            limit=100,
            offset=offset,
            with_payload=False,
            with_vectors=True,
        )
        if not points:
            break
        for point in points:
            total_images += 1
            # Only a named-vector dict can carry "dinov2_embedding". A missing
            # (None) or unnamed (list) vector counts as not yet migrated.
            vectors = point.vector
            if isinstance(vectors, dict) and "dinov2_embedding" in vectors:
                has_dinov2 += 1
        # A None offset means the scroll has no further pages.
        if offset is None:
            break

    return {
        "total_image_points": total_images,
        "with_dinov2": has_dinov2,
        "without_dinov2": total_images - has_dinov2,
        "migration_progress": (has_dinov2 / total_images * 100) if total_images > 0 else 0,
    }
# Example usage functions
def run_full_migration():
    """Migrate every image point and print status before and after.

    If the initial status shows no image points missing a DINOv2 embedding,
    prints a message and returns without migrating.
    """
    print("Starting full migration...")

    # Check status first so there is no work when nothing is pending.
    initial_status = get_migration_status()
    print(f"Migration status: {initial_status}")
    remaining = initial_status["without_dinov2"]
    if remaining == 0:
        print("All points already have DINOv2 embeddings!")
        return

    migration_stats = migrate_clip_to_dinov2(verbose=True)
    print("\nMigration completed!")
    print(f"Statistics: {migration_stats}")

    # Report the post-migration status for comparison.
    print(f"Final status: {get_migration_status()}")
def run_dry_run():
    """Preview the migration without writing anything.

    Returns:
        The statistics dictionary from ``migrate_clip_to_dinov2`` run in
        dry-run mode.
    """
    print("Running dry run...")
    preview = migrate_clip_to_dinov2(verbose=True, dry_run=True)
    print(f"Dry run results: {preview}")
    return preview
if __name__ == "__main__":
    # Local import: only needed when run as a script.
    import argparse

    parser = argparse.ArgumentParser(
        description="Add DINOv2 embeddings to existing image points in Qdrant."
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Only report what would be migrated; make no changes.",
    )
    args = parser.parse_args()

    # With no flags this performs a full migration, as before. Pass
    # --dry-run instead of editing the source to preview the migration.
    if args.dry_run:
        run_dry_run()
    else:
        run_full_migration()