| """ |
| EmbeddingStore — Unified interface for vector storage and retrieval. |
| Wraps PostgreSQL + pgvector for CLIP embeddings and asset/brochure queries. |
| """ |
|
|
| import logging |
| from typing import List, Dict, Optional, Any |
| from uuid import UUID |
| import json |
|
|
| import numpy as np |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class EmbeddingStore:
    """
    PostgreSQL + pgvector embedding store for the Reel Creator Platform.

    Provides:
    - Asset metadata CRUD with vector search
    - Brochure node storage with CLIP embeddings
    - Brochure-to-asset mapping storage
    - Caption and voiceover library queries
    - Semantic search via cosine similarity

    One psycopg2 connection is opened lazily and shared by every method.
    Writes commit on success, and any failed statement is rolled back (reads
    included). This keeps a single bad query from leaving the shared
    connection stuck in an aborted transaction. Database errors are logged
    and surfaced as ``None`` / ``False`` / empty results instead of being
    raised.
    """

    # Hugging Face id of the CLIP checkpoint used by embed_text/embed_image.
    # Its output size must match the ``embedding_768`` vector columns.
    CLIP_MODEL_NAME = "openai/clip-vit-large-patch14"

    def __init__(
        self,
        db_connection_string: Optional[str] = None,
        pool_size: int = 10,
    ):
        """
        Args:
            db_connection_string: DSN passed verbatim to ``psycopg2.connect``.
                If omitted, connects to database ``reel_creator`` on
                localhost:5432 as ``reel_user``.
            pool_size: Stored for a future connection pool. A single shared
                connection is currently used regardless of this value.
        """
        self.connection_string = db_connection_string
        self.pool_size = pool_size
        # Despite the name, this holds one psycopg2 connection, not a pool.
        self._pool = None
        self._clip_model = None
        self._clip_processor = None
        self._clip_device: Optional[str] = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _get_connection(self):
        """
        Return the shared database connection, opening it on first use.

        If psycopg2 reports the cached connection as closed (for example,
        after the server dropped it), it is discarded and reopened.

        Returns:
            The psycopg2 connection, or ``None`` if connecting failed. The
            failure is logged and the connect is retried on the next call.

        Raises:
            ImportError: If ``psycopg2`` is not installed.
        """
        if self._pool is not None and getattr(self._pool, "closed", 0):
            logger.warning("Database connection was closed; reconnecting")
            self._pool = None

        if self._pool is None:
            try:
                import psycopg2
            except ImportError as e:
                raise ImportError(
                    "PostgreSQL support requires 'psycopg2-binary'. "
                    "Install: pip install psycopg2-binary"
                ) from e

            try:
                if self.connection_string:
                    self._pool = psycopg2.connect(self.connection_string)
                else:
                    self._pool = psycopg2.connect(
                        host="localhost",
                        port=5432,
                        dbname="reel_creator",
                        user="reel_user",
                    )
            except Exception as e:
                logger.warning("Database connection not available: %s", e)
                self._pool = None

        return self._pool

    @staticmethod
    def _safe_rollback(conn) -> None:
        """Roll back *conn*, logging instead of raising if that fails too."""
        try:
            conn.rollback()
        except Exception as e:
            # The connection may already be dead. _get_connection will
            # notice ``closed`` and reconnect on the next call.
            logger.warning("Rollback failed: %s", e)

    @staticmethod
    def _fetch_dicts(cur) -> List[Dict[str, Any]]:
        """Return all remaining rows of *cur* as ``{column_name: value}`` dicts."""
        columns = [desc[0] for desc in cur.description]
        return [dict(zip(columns, row)) for row in cur.fetchall()]

    @staticmethod
    def _to_vector_param(embedding: Any) -> Optional[List[float]]:
        """
        Convert an embedding into a plain list of Python floats for psycopg2.

        Accepts lists, tuples or NumPy arrays. Conversion is needed because
        psycopg2 cannot adapt an ``ndarray``, and ``if array:`` raises for
        multi-element arrays. Multi-dimensional input is flattened.

        Returns:
            The flattened float list, or ``None`` for ``None``/empty input.
        """
        if embedding is None:
            return None
        values = np.asarray(embedding, dtype=float).reshape(-1)
        if values.size == 0:
            return None
        return values.tolist()

    @staticmethod
    def _enum_value(value: Any) -> Any:
        """Return ``value.value`` for enum-like objects, else *value* unchanged."""
        return value.value if hasattr(value, "value") else value

    # ------------------------------------------------------------------
    # CLIP embeddings
    # ------------------------------------------------------------------

    def _get_clip_model(self):
        """
        Lazy-load the CLIP model and processor (on CUDA when available).

        Returns:
            ``(model, processor)``. Also sets ``self._clip_device``.

        Raises:
            Exception: Any import or download/load error, re-raised after
                logging. Nothing is cached in that case.
        """
        if self._clip_model is None:
            try:
                import torch
                from transformers import CLIPModel, CLIPProcessor

                model_name = self.CLIP_MODEL_NAME
                device = "cuda" if torch.cuda.is_available() else "cpu"

                logger.info("Loading CLIP model: %s on %s", model_name, device)
                model = CLIPModel.from_pretrained(model_name).to(device)
                processor = CLIPProcessor.from_pretrained(model_name)
            except Exception as e:
                logger.error("Failed to load CLIP model: %s", e)
                raise

            # Publish all three together, so a failure part-way through never
            # leaves a model cached without its processor or device.
            self._clip_model = model
            self._clip_processor = processor
            self._clip_device = device

        return self._clip_model, self._clip_processor

    def embed_text(self, text: str) -> List[float]:
        """
        Generate an L2-normalised CLIP text embedding for a query string.

        Input is truncated to 77 tokens (``max_length=77``).

        Returns:
            The embedding as a list of Python floats.
        """
        import torch

        model, processor = self._get_clip_model()

        inputs = processor(
            text=[text],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77,
        ).to(self._clip_device)

        with torch.no_grad():
            text_features = model.get_text_features(**inputs)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        return text_features.cpu().numpy()[0].tolist()

    def embed_image(self, image_path: str) -> List[float]:
        """
        Generate an L2-normalised CLIP image embedding for the image at *image_path*.

        Returns:
            The embedding as a list of Python floats.
        """
        import torch
        from PIL import Image

        model, processor = self._get_clip_model()

        with Image.open(image_path) as img:
            # convert() returns an independent, fully loaded image, so the
            # underlying file handle can be released right away.
            image = img.convert("RGB")

        inputs = processor(
            images=image,
            return_tensors="pt",
        ).to(self._clip_device)

        with torch.no_grad():
            image_features = model.get_image_features(**inputs)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        return image_features.cpu().numpy()[0].tolist()

    # ------------------------------------------------------------------
    # Assets
    # ------------------------------------------------------------------

    def insert_asset(self, asset_data: Dict[str, Any]) -> Optional[UUID]:
        """
        Insert an asset record.

        Args:
            asset_data: Must contain ``file_path``, ``file_name`` and
                ``asset_type``. ``source``, ``resolution``, ``duration_ms``,
                ``frame_rate`` and ``file_size_bytes`` are optional.

        Returns:
            The new row's ``id``, or ``None`` on error or without a connection.
        """
        conn = self._get_connection()
        if not conn:
            logger.warning("No DB connection, skipping insert")
            return None

        try:
            with conn.cursor() as cur:
                cur.execute("""
                    INSERT INTO assets (file_path, file_name, asset_type, source,
                                        resolution, duration_ms, frame_rate, file_size_bytes)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                    RETURNING id
                """, (
                    asset_data["file_path"],
                    asset_data["file_name"],
                    asset_data["asset_type"],
                    asset_data.get("source"),
                    asset_data.get("resolution"),
                    asset_data.get("duration_ms"),
                    asset_data.get("frame_rate"),
                    asset_data.get("file_size_bytes"),
                ))
                result = cur.fetchone()
            conn.commit()
            return result[0] if result else None
        except Exception as e:
            logger.error("Error inserting asset: %s", e)
            self._safe_rollback(conn)
            return None

    def insert_asset_metadata(
        self,
        asset_id: UUID,
        metadata: Dict[str, Any],
        embedding: Optional[List[float]] = None,
    ) -> bool:
        """
        Upsert metadata for *asset_id*, with an optional CLIP embedding.

        An existing row for the same ``asset_id`` has all listed columns
        overwritten. A ``None``/empty *embedding* is stored as NULL. NumPy
        arrays are accepted as well as lists.

        Returns:
            ``True`` on success, ``False`` on error or without a connection.
        """
        conn = self._get_connection()
        if not conn:
            return False

        try:
            embedding_vector = self._to_vector_param(embedding)
            with conn.cursor() as cur:
                cur.execute("""
                    INSERT INTO asset_metadata
                    (asset_id, description, shot_type, camera_angle, subject_part,
                     mood, dominant_colours, confidence_score, review_flag, embedding_768)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                    ON CONFLICT (asset_id) DO UPDATE SET
                        description = EXCLUDED.description,
                        shot_type = EXCLUDED.shot_type,
                        camera_angle = EXCLUDED.camera_angle,
                        subject_part = EXCLUDED.subject_part,
                        mood = EXCLUDED.mood,
                        dominant_colours = EXCLUDED.dominant_colours,
                        confidence_score = EXCLUDED.confidence_score,
                        review_flag = EXCLUDED.review_flag,
                        embedding_768 = EXCLUDED.embedding_768
                """, (
                    asset_id,
                    metadata.get("description"),
                    metadata.get("shot_type"),
                    metadata.get("camera_angle"),
                    metadata.get("subject_part"),
                    metadata.get("mood"),
                    metadata.get("dominant_colours"),
                    metadata.get("confidence_score", 1.0),
                    metadata.get("review_flag", False),
                    embedding_vector,
                ))
            conn.commit()
            return True
        except Exception as e:
            logger.error("Error inserting asset metadata: %s", e)
            self._safe_rollback(conn)
            return False

    def search_assets(
        self,
        query_embedding: List[float],
        limit: int = 10,
        asset_type: Optional[str] = None,
        subject_part: Optional[str] = None,
        mood: Optional[str] = None,
        min_confidence: float = 0.5,
    ) -> List[Dict[str, Any]]:
        """
        Semantic search for assets using cosine similarity.

        Args:
            query_embedding: Query vector (list or NumPy array), e.g. from
                :meth:`embed_text`.
            limit: Maximum number of rows returned.
            asset_type, subject_part, mood: Optional exact-match filters.
            min_confidence: Minimum ``asset_metadata.confidence_score``.

        Returns:
            Dicts with asset/metadata columns plus ``similarity``
            (``1 - cosine distance``), nearest first. The list is empty on
            error, without a connection, or for an empty query vector.
        """
        conn = self._get_connection()
        if not conn:
            logger.warning("No DB connection, returning empty results")
            return []

        try:
            vector = self._to_vector_param(query_embedding)
            if vector is None:
                logger.warning("Empty query embedding, returning empty results")
                return []

            with conn.cursor() as cur:
                query = """
                    SELECT
                        a.id as asset_id,
                        a.file_path,
                        a.asset_type,
                        am.description,
                        am.shot_type,
                        am.camera_angle,
                        am.subject_part,
                        am.mood,
                        1 - (am.embedding_768 <=> %s::vector) as similarity
                    FROM assets a
                    JOIN asset_metadata am ON a.id = am.asset_id
                    WHERE am.embedding_768 IS NOT NULL
                      AND am.confidence_score >= %s
                """
                params: List[Any] = [vector, min_confidence]

                if asset_type:
                    query += " AND a.asset_type = %s"
                    params.append(asset_type)
                if subject_part:
                    query += " AND am.subject_part = %s"
                    params.append(subject_part)
                if mood:
                    query += " AND am.mood = %s"
                    params.append(mood)

                # The vector appears twice: once for the SELECT-ed similarity
                # and once for the ORDER BY distance.
                query += " ORDER BY am.embedding_768 <=> %s::vector LIMIT %s"
                params.extend([vector, limit])

                cur.execute(query, params)
                return self._fetch_dicts(cur)

        except Exception as e:
            logger.error("Error searching assets: %s", e)
            self._safe_rollback(conn)
            return []

    # ------------------------------------------------------------------
    # Video events
    # ------------------------------------------------------------------

    def insert_video_event(self, event_data: Dict[str, Any]) -> Optional[UUID]:
        """
        Insert a video event segment.

        Args:
            event_data: Must contain ``asset_id``, ``start_ms`` and ``end_ms``.
                Descriptive fields, ``embedding_768`` (list or NumPy array),
                ``confidence_score`` (default 1.0) and ``keyframe_path`` are
                optional.

        Returns:
            The new row's ``id``, or ``None`` on error or without a connection.
        """
        conn = self._get_connection()
        if not conn:
            return None

        try:
            embedding = self._to_vector_param(event_data.get("embedding_768"))
            with conn.cursor() as cur:
                cur.execute("""
                    INSERT INTO video_events
                    (asset_id, start_ms, end_ms, description, shot_type,
                     camera_angle, subject_part, mood, embedding_768, confidence_score, keyframe_path)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                    RETURNING id
                """, (
                    event_data["asset_id"],
                    event_data["start_ms"],
                    event_data["end_ms"],
                    event_data.get("description"),
                    event_data.get("shot_type"),
                    event_data.get("camera_angle"),
                    event_data.get("subject_part"),
                    event_data.get("mood"),
                    embedding,
                    event_data.get("confidence_score", 1.0),
                    event_data.get("keyframe_path"),
                ))
                result = cur.fetchone()
            conn.commit()
            return result[0] if result else None
        except Exception as e:
            logger.error("Error inserting video event: %s", e)
            self._safe_rollback(conn)
            return None

    def search_video_events(
        self,
        query_embedding: List[float],
        limit: int = 10,
        subject_part: Optional[str] = None,
        mood: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """
        Semantic search for video events using cosine similarity.

        Returns:
            Dicts with event columns, the parent asset's ``file_path`` and
            ``similarity`` (``1 - cosine distance``), nearest first. The list
            is empty on error, without a connection, or for an empty query
            vector.
        """
        conn = self._get_connection()
        if not conn:
            return []

        try:
            vector = self._to_vector_param(query_embedding)
            if vector is None:
                logger.warning("Empty query embedding, returning empty results")
                return []

            with conn.cursor() as cur:
                query = """
                    SELECT
                        ve.id as event_id,
                        ve.asset_id,
                        a.file_path,
                        ve.start_ms,
                        ve.end_ms,
                        ve.duration_ms,
                        ve.description,
                        ve.subject_part,
                        ve.mood,
                        1 - (ve.embedding_768 <=> %s::vector) as similarity
                    FROM video_events ve
                    JOIN assets a ON ve.asset_id = a.id
                    WHERE ve.embedding_768 IS NOT NULL
                """
                params: List[Any] = [vector]

                if subject_part:
                    query += " AND ve.subject_part = %s"
                    params.append(subject_part)
                if mood:
                    query += " AND ve.mood = %s"
                    params.append(mood)

                query += " ORDER BY ve.embedding_768 <=> %s::vector LIMIT %s"
                params.extend([vector, limit])

                cur.execute(query, params)
                return self._fetch_dicts(cur)

        except Exception as e:
            logger.error("Error searching video events: %s", e)
            self._safe_rollback(conn)
            return []

    # ------------------------------------------------------------------
    # Brochure nodes
    # ------------------------------------------------------------------

    def insert_brochure_node(self, node) -> Optional[UUID]:
        """
        Insert a brochure node with its embedding.

        Args:
            node: Object exposing ``section`` (enum-like, ``.value`` is
                stored), ``title``, ``content``, ``key_features``,
                ``taglines``, ``spec_highlights`` (JSON-serialised),
                ``car_part_referenced``, ``tone_tags`` (enum-likes),
                ``embedding_768`` (list or NumPy array), ``page_number`` and
                ``source_pdf``.

        Returns:
            The new row's ``id``, or ``None`` on error or without a connection.
        """
        conn = self._get_connection()
        if not conn:
            return None

        try:
            embedding = self._to_vector_param(node.embedding_768)
            with conn.cursor() as cur:
                cur.execute("""
                    INSERT INTO brochure_nodes
                    (section, title, content, key_features, taglines,
                     spec_highlights, car_part_referenced, tone_tags,
                     embedding_768, page_number, source_pdf)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                    RETURNING id
                """, (
                    node.section.value,
                    node.title,
                    node.content,
                    node.key_features,
                    node.taglines,
                    json.dumps(node.spec_highlights) if node.spec_highlights else None,
                    node.car_part_referenced,
                    [t.value for t in node.tone_tags] if node.tone_tags else None,
                    embedding,
                    node.page_number,
                    node.source_pdf,
                ))
                result = cur.fetchone()
            conn.commit()
            return result[0] if result else None
        except Exception as e:
            logger.error("Error inserting brochure node: %s", e)
            self._safe_rollback(conn)
            return None

    # ------------------------------------------------------------------
    # Brochure <-> asset mapping
    # ------------------------------------------------------------------

    def insert_brochure_asset_map(self, mapping) -> bool:
        """
        Upsert a brochure-to-asset mapping.

        On conflict, only ``similarity_score``, ``confidence_score``,
        ``rank`` and ``updated_at`` are refreshed. An existing row keeps its
        ``mapping_type``, ``is_approved`` and ``reviewer_notes``.

        Returns:
            ``True`` on success, ``False`` on error or without a connection.
        """
        conn = self._get_connection()
        if not conn:
            return False

        try:
            with conn.cursor() as cur:
                # NOTE(review): unless the unique index on these columns is
                # declared NULLS NOT DISTINCT (PostgreSQL 15+), rows whose
                # video_event_id is NULL never conflict. Re-inserting such a
                # mapping would then add a duplicate row. Confirm against the
                # schema.
                cur.execute("""
                    INSERT INTO brochure_asset_map
                    (brochure_node_id, asset_id, video_event_id, similarity_score,
                     mapping_type, confidence_score, is_approved, reviewer_notes, rank)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                    ON CONFLICT (brochure_node_id, asset_id, video_event_id) DO UPDATE SET
                        similarity_score = EXCLUDED.similarity_score,
                        confidence_score = EXCLUDED.confidence_score,
                        rank = EXCLUDED.rank,
                        updated_at = now()
                """, (
                    mapping.brochure_node_id,
                    mapping.asset_id,
                    mapping.video_event_id,
                    mapping.similarity_score,
                    mapping.mapping_type.value,
                    mapping.confidence_score,
                    mapping.is_approved,
                    mapping.reviewer_notes,
                    mapping.rank,
                ))
            conn.commit()
            return True
        except Exception as e:
            logger.error("Error inserting brochure asset map: %s", e)
            self._safe_rollback(conn)
            return False

    def get_brochure_mapped_assets(
        self,
        car_part: Optional[str] = None,
        limit: int = 10,
    ) -> List[Dict[str, Any]]:
        """
        Get assets mapped to brochure nodes, optionally only nodes mentioning *car_part*.

        Rejected mappings (``is_approved = false``) are excluded. Approved
        and pending (NULL) ones are returned, highest confidence first, with
        NULL confidences last.

        Returns:
            Mapping rows as dicts. The list is empty on error or without a
            connection.
        """
        conn = self._get_connection()
        if not conn:
            return []

        try:
            with conn.cursor() as cur:
                query = """
                    SELECT
                        bam.brochure_node_id,
                        bam.asset_id,
                        bam.video_event_id,
                        bam.similarity_score,
                        bam.mapping_type,
                        bam.confidence_score,
                        a.file_path,
                        ve.start_ms,
                        ve.end_ms,
                        bn.car_part_referenced
                    FROM brochure_asset_map bam
                    JOIN brochure_nodes bn ON bam.brochure_node_id = bn.id
                    JOIN assets a ON bam.asset_id = a.id
                    LEFT JOIN video_events ve ON bam.video_event_id = ve.id
                    WHERE bam.is_approved IS NOT FALSE
                """
                params: List[Any] = []

                if car_part:
                    query += " AND %s = ANY(bn.car_part_referenced)"
                    params.append(car_part)

                # PostgreSQL sorts NULLs first under DESC by default. Without
                # NULLS LAST, unscored mappings would outrank scored ones.
                query += " ORDER BY bam.confidence_score DESC NULLS LAST LIMIT %s"
                params.append(limit)

                cur.execute(query, params)
                return self._fetch_dicts(cur)

        except Exception as e:
            logger.error("Error getting brochure mapped assets: %s", e)
            self._safe_rollback(conn)
            return []

    # ------------------------------------------------------------------
    # Caption / voiceover libraries
    # ------------------------------------------------------------------

    def insert_caption(self, caption) -> Optional[UUID]:
        """
        Insert a caption variant.

        ``caption.tone`` and ``caption.duration_class`` are enum-likes whose
        ``.value`` is stored.

        Returns:
            The new row's ``id``, or ``None`` on error or without a connection.
        """
        conn = self._get_connection()
        if not conn:
            return None

        try:
            with conn.cursor() as cur:
                cur.execute("""
                    INSERT INTO captions_library
                    (brochure_node_id, car_part, tone, duration_class, text,
                     is_brand_compliant, compliance_notes)
                    VALUES (%s, %s, %s, %s, %s, %s, %s)
                    RETURNING id
                """, (
                    caption.brochure_node_id,
                    caption.car_part,
                    caption.tone.value,
                    caption.duration_class.value,
                    caption.text,
                    caption.is_brand_compliant,
                    caption.compliance_notes,
                ))
                result = cur.fetchone()
            conn.commit()
            return result[0] if result else None
        except Exception as e:
            logger.error("Error inserting caption: %s", e)
            self._safe_rollback(conn)
            return None

    def insert_voiceover(self, voiceover) -> Optional[UUID]:
        """
        Insert a voiceover line.

        ``voiceover.tone`` and ``voiceover.duration_class`` are enum-likes
        whose ``.value`` is stored.

        Returns:
            The new row's ``id``, or ``None`` on error or without a connection.
        """
        conn = self._get_connection()
        if not conn:
            return None

        try:
            with conn.cursor() as cur:
                cur.execute("""
                    INSERT INTO voiceover_library
                    (brochure_node_id, car_part, tone, duration_class, text,
                     estimated_duration_ms, is_brand_compliant, compliance_notes)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                    RETURNING id
                """, (
                    voiceover.brochure_node_id,
                    voiceover.car_part,
                    voiceover.tone.value,
                    voiceover.duration_class.value,
                    voiceover.text,
                    voiceover.estimated_duration_ms,
                    voiceover.is_brand_compliant,
                    voiceover.compliance_notes,
                ))
                result = cur.fetchone()
            conn.commit()
            return result[0] if result else None
        except Exception as e:
            logger.error("Error inserting voiceover: %s", e)
            self._safe_rollback(conn)
            return None

    def _query_library(
        self,
        table: str,
        label: str,
        car_part: Optional[str],
        tone,
        duration_class,
        limit: int,
    ) -> List[Dict[str, Any]]:
        """
        Run the shared filtered query for the caption and voiceover libraries.

        *table* is interpolated into the SQL, so it must only ever be one of
        the hard-coded table names passed by the public wrappers, never
        caller input. *label* is used in the error log message. Rows are
        ordered least-used first (``usage_count ASC``).
        """
        conn = self._get_connection()
        if not conn:
            return []

        try:
            with conn.cursor() as cur:
                query = f"SELECT * FROM {table} WHERE 1=1"
                params: List[Any] = []

                if car_part:
                    query += " AND car_part = %s"
                    params.append(car_part)
                if tone:
                    query += " AND tone = %s"
                    params.append(self._enum_value(tone))
                if duration_class:
                    query += " AND duration_class = %s"
                    params.append(self._enum_value(duration_class))

                query += " ORDER BY usage_count ASC LIMIT %s"
                params.append(limit)

                cur.execute(query, params)
                return self._fetch_dicts(cur)

        except Exception as e:
            logger.error("Error querying %s: %s", label, e)
            self._safe_rollback(conn)
            return []

    def query_captions(
        self,
        car_part: Optional[str] = None,
        tone=None,
        duration_class=None,
        limit: int = 10,
    ) -> List[Any]:
        """
        Query the caption library with optional filters.

        *tone* and *duration_class* may be enum members or raw values.

        Returns:
            Rows as dicts, least-used first. The list is empty on error or
            without a connection.
        """
        return self._query_library(
            "captions_library", "captions", car_part, tone, duration_class, limit
        )

    def query_voiceovers(
        self,
        car_part: Optional[str] = None,
        tone=None,
        duration_class=None,
        limit: int = 10,
    ) -> List[Any]:
        """
        Query the voiceover library with optional filters.

        *tone* and *duration_class* may be enum members or raw values.

        Returns:
            Rows as dicts, least-used first. The list is empty on error or
            without a connection.
        """
        return self._query_library(
            "voiceover_library", "voiceovers", car_part, tone, duration_class, limit
        )

    def get_mapping_stats(self) -> Dict[str, int]:
        """
        Count brochure-asset mappings by approval status.

        Returns:
            A dict with a subset of the keys ``'approved'``, ``'rejected'``
            and ``'pending'`` (NULL ``is_approved``). Only statuses that
            occur are present. The dict is empty on error or without a
            connection.
        """
        conn = self._get_connection()
        if not conn:
            return {}

        try:
            with conn.cursor() as cur:
                cur.execute("""
                    SELECT
                        CASE
                            WHEN is_approved = true THEN 'approved'
                            WHEN is_approved = false THEN 'rejected'
                            ELSE 'pending'
                        END as status,
                        COUNT(*) as count
                    FROM brochure_asset_map
                    GROUP BY is_approved
                """)
                return {status: count for status, count in cur.fetchall()}
        except Exception as e:
            logger.error("Error getting mapping stats: %s", e)
            self._safe_rollback(conn)
            return {}
|
|