""" EmbeddingStore — Unified interface for vector storage and retrieval. Wraps PostgreSQL + pgvector for CLIP embeddings and asset/brochure queries. """ import logging from typing import List, Dict, Optional, Any from uuid import UUID import json import numpy as np logger = logging.getLogger(__name__) class EmbeddingStore: """ PostgreSQL + pgvector embedding store for the Reel Creator Platform. Provides: - Asset metadata CRUD with vector search - Brochure node storage with CLIP embeddings - Brochure-to-asset mapping storage - Caption and voiceover library queries - Semantic search via cosine similarity """ def __init__( self, db_connection_string: Optional[str] = None, pool_size: int = 10, ): self.connection_string = db_connection_string self._pool = None self._clip_model = None self._clip_processor = None def _get_connection(self): """Lazy-init database connection pool.""" if self._pool is None: try: import psycopg2 from psycopg2.extras import RealDictCursor if self.connection_string: self._pool = psycopg2.connect(self.connection_string) else: # Default local connection self._pool = psycopg2.connect( host="localhost", port=5432, dbname="reel_creator", user="reel_user", ) except ImportError: raise ImportError( "PostgreSQL support requires 'psycopg2-binary'. " "Install: pip install psycopg2-binary" ) except Exception as e: logger.warning(f"Database connection not available: {e}") self._pool = None return self._pool def _get_clip_model(self): """Lazy-load CLIP model for text embedding.""" if self._clip_model is None: try: import torch from transformers import CLIPModel, CLIPProcessor model_name = "openai/clip-vit-large-patch14" device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"Loading CLIP model: {model_name} on {device}") self._clip_model = CLIPModel.from_pretrained(model_name).to(device) self._clip_processor = CLIPProcessor.from_pretrained(model_name) self._clip_device = device except Exception as e: logger.error(f"Failed to load CLIP model: {e}") raise return self._clip_model, self._clip_processor def embed_text(self, text: str) -> List[float]: """Generate CLIP text embedding for a query string.""" import torch model, processor = self._get_clip_model() inputs = processor( text=[text], return_tensors="pt", padding=True, truncation=True, max_length=77, ).to(self._clip_device) with torch.no_grad(): text_features = model.get_text_features(**inputs) text_features = text_features / text_features.norm(dim=-1, keepdim=True) return text_features.cpu().numpy()[0].tolist() def embed_image(self, image_path: str) -> List[float]: """Generate CLIP image embedding.""" import torch from PIL import Image model, processor = self._get_clip_model() image = Image.open(image_path).convert("RGB") inputs = processor( images=image, return_tensors="pt", ).to(self._clip_device) with torch.no_grad(): image_features = model.get_image_features(**inputs) image_features = image_features / image_features.norm(dim=-1, keepdim=True) return image_features.cpu().numpy()[0].tolist() # ============================================================ # ASSET OPERATIONS # ============================================================ def insert_asset(self, asset_data: Dict[str, Any]) -> Optional[UUID]: """Insert an asset record.""" conn = self._get_connection() if not conn: logger.warning("No DB connection, skipping insert") return None try: with conn.cursor() as cur: cur.execute(""" INSERT INTO assets (file_path, file_name, asset_type, source, resolution, duration_ms, frame_rate, file_size_bytes) VALUES (%s, %s, %s, %s, %s, %s, %s, %s) RETURNING id """, ( asset_data["file_path"], asset_data["file_name"], asset_data["asset_type"], asset_data.get("source"), asset_data.get("resolution"), asset_data.get("duration_ms"), asset_data.get("frame_rate"), asset_data.get("file_size_bytes"), )) result = cur.fetchone() conn.commit() return result[0] if result else None except Exception as e: logger.error(f"Error inserting asset: {e}") conn.rollback() return None def insert_asset_metadata( self, asset_id: UUID, metadata: Dict[str, Any], embedding: Optional[List[float]] = None, ) -> bool: """Insert asset metadata with optional CLIP embedding.""" conn = self._get_connection() if not conn: return False try: with conn.cursor() as cur: embedding_vector = embedding if embedding else None cur.execute(""" INSERT INTO asset_metadata (asset_id, description, shot_type, camera_angle, subject_part, mood, dominant_colours, confidence_score, review_flag, embedding_768) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) ON CONFLICT (asset_id) DO UPDATE SET description = EXCLUDED.description, shot_type = EXCLUDED.shot_type, camera_angle = EXCLUDED.camera_angle, subject_part = EXCLUDED.subject_part, mood = EXCLUDED.mood, dominant_colours = EXCLUDED.dominant_colours, confidence_score = EXCLUDED.confidence_score, review_flag = EXCLUDED.review_flag, embedding_768 = EXCLUDED.embedding_768 """, ( asset_id, metadata.get("description"), metadata.get("shot_type"), metadata.get("camera_angle"), metadata.get("subject_part"), metadata.get("mood"), metadata.get("dominant_colours"), metadata.get("confidence_score", 1.0), metadata.get("review_flag", False), embedding_vector, )) conn.commit() return True except Exception as e: logger.error(f"Error inserting asset metadata: {e}") conn.rollback() return False def search_assets( self, query_embedding: List[float], limit: int = 10, asset_type: Optional[str] = None, subject_part: Optional[str] = None, mood: Optional[str] = None, min_confidence: float = 0.5, ) -> List[Dict[str, Any]]: """Semantic search for assets using cosine similarity.""" conn = self._get_connection() if not conn: logger.warning("No DB connection, returning empty results") return [] try: with conn.cursor() as cur: query = """ SELECT a.id as asset_id, a.file_path, a.asset_type, am.description, am.shot_type, am.camera_angle, am.subject_part, am.mood, 1 - (am.embedding_768 <=> %s::vector) as similarity FROM assets a JOIN asset_metadata am ON a.id = am.asset_id WHERE am.embedding_768 IS NOT NULL AND am.confidence_score >= %s """ params = [query_embedding, min_confidence] if asset_type: query += " AND a.asset_type = %s" params.append(asset_type) if subject_part: query += " AND am.subject_part = %s" params.append(subject_part) if mood: query += " AND am.mood = %s" params.append(mood) query += " ORDER BY am.embedding_768 <=> %s::vector LIMIT %s" params.extend([query_embedding, limit]) cur.execute(query, params) columns = [desc[0] for desc in cur.description] results = [] for row in cur.fetchall(): results.append(dict(zip(columns, row))) return results except Exception as e: logger.error(f"Error searching assets: {e}") return [] # ============================================================ # VIDEO EVENT OPERATIONS # ============================================================ def insert_video_event(self, event_data: Dict[str, Any]) -> Optional[UUID]: """Insert a video event segment.""" conn = self._get_connection() if not conn: return None try: with conn.cursor() as cur: cur.execute(""" INSERT INTO video_events (asset_id, start_ms, end_ms, description, shot_type, camera_angle, subject_part, mood, embedding_768, confidence_score, keyframe_path) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) RETURNING id """, ( event_data["asset_id"], event_data["start_ms"], event_data["end_ms"], event_data.get("description"), event_data.get("shot_type"), event_data.get("camera_angle"), event_data.get("subject_part"), event_data.get("mood"), event_data.get("embedding_768"), event_data.get("confidence_score", 1.0), event_data.get("keyframe_path"), )) result = cur.fetchone() conn.commit() return result[0] if result else None except Exception as e: logger.error(f"Error inserting video event: {e}") conn.rollback() return None def search_video_events( self, query_embedding: List[float], limit: int = 10, subject_part: Optional[str] = None, mood: Optional[str] = None, ) -> List[Dict[str, Any]]: """Semantic search for video events.""" conn = self._get_connection() if not conn: return [] try: with conn.cursor() as cur: query = """ SELECT ve.id as event_id, ve.asset_id, a.file_path, ve.start_ms, ve.end_ms, ve.duration_ms, ve.description, ve.subject_part, ve.mood, 1 - (ve.embedding_768 <=> %s::vector) as similarity FROM video_events ve JOIN assets a ON ve.asset_id = a.id WHERE ve.embedding_768 IS NOT NULL """ params = [query_embedding] if subject_part: query += " AND ve.subject_part = %s" params.append(subject_part) if mood: query += " AND ve.mood = %s" params.append(mood) query += " ORDER BY ve.embedding_768 <=> %s::vector LIMIT %s" params.extend([query_embedding, limit]) cur.execute(query, params) columns = [desc[0] for desc in cur.description] results = [] for row in cur.fetchall(): results.append(dict(zip(columns, row))) return results except Exception as e: logger.error(f"Error searching video events: {e}") return [] # ============================================================ # BROCHURE NODE OPERATIONS # ============================================================ def insert_brochure_node(self, node) -> Optional[UUID]: """Insert a brochure node with embedding.""" conn = self._get_connection() if not conn: return None try: with conn.cursor() as cur: embedding = node.embedding_768 if node.embedding_768 else None cur.execute(""" INSERT INTO brochure_nodes (section, title, content, key_features, taglines, spec_highlights, car_part_referenced, tone_tags, embedding_768, page_number, source_pdf) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) RETURNING id """, ( node.section.value, node.title, node.content, node.key_features, node.taglines, json.dumps(node.spec_highlights) if node.spec_highlights else None, node.car_part_referenced, [t.value for t in node.tone_tags] if node.tone_tags else None, embedding, node.page_number, node.source_pdf, )) result = cur.fetchone() conn.commit() return result[0] if result else None except Exception as e: logger.error(f"Error inserting brochure node: {e}") conn.rollback() return None # ============================================================ # BROCHURE-ASSET MAP OPERATIONS # ============================================================ def insert_brochure_asset_map(self, mapping) -> bool: """Insert a brochure-to-asset mapping.""" conn = self._get_connection() if not conn: return False try: with conn.cursor() as cur: cur.execute(""" INSERT INTO brochure_asset_map (brochure_node_id, asset_id, video_event_id, similarity_score, mapping_type, confidence_score, is_approved, reviewer_notes, rank) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) ON CONFLICT (brochure_node_id, asset_id, video_event_id) DO UPDATE SET similarity_score = EXCLUDED.similarity_score, confidence_score = EXCLUDED.confidence_score, rank = EXCLUDED.rank, updated_at = now() """, ( mapping.brochure_node_id, mapping.asset_id, mapping.video_event_id, mapping.similarity_score, mapping.mapping_type.value, mapping.confidence_score, mapping.is_approved, mapping.reviewer_notes, mapping.rank, )) conn.commit() return True except Exception as e: logger.error(f"Error inserting brochure asset map: {e}") conn.rollback() return False def get_brochure_mapped_assets( self, car_part: Optional[str] = None, limit: int = 10, ) -> List[Dict[str, Any]]: """Get assets mapped to brochure nodes mentioning a car part.""" conn = self._get_connection() if not conn: return [] try: with conn.cursor() as cur: query = """ SELECT bam.brochure_node_id, bam.asset_id, bam.video_event_id, bam.similarity_score, bam.mapping_type, bam.confidence_score, a.file_path, ve.start_ms, ve.end_ms, bn.car_part_referenced FROM brochure_asset_map bam JOIN brochure_nodes bn ON bam.brochure_node_id = bn.id JOIN assets a ON bam.asset_id = a.id LEFT JOIN video_events ve ON bam.video_event_id = ve.id WHERE bam.is_approved IS NOT FALSE """ params = [] if car_part: query += " AND %s = ANY(bn.car_part_referenced)" params.append(car_part) query += " ORDER BY bam.confidence_score DESC LIMIT %s" params.append(limit) cur.execute(query, params) columns = [desc[0] for desc in cur.description] results = [] for row in cur.fetchall(): results.append(dict(zip(columns, row))) return results except Exception as e: logger.error(f"Error getting brochure mapped assets: {e}") return [] # ============================================================ # CAPTION & VOICEOVER OPERATIONS # ============================================================ def insert_caption(self, caption) -> Optional[UUID]: """Insert a caption variant.""" conn = self._get_connection() if not conn: return None try: with conn.cursor() as cur: cur.execute(""" INSERT INTO captions_library (brochure_node_id, car_part, tone, duration_class, text, is_brand_compliant, compliance_notes) VALUES (%s, %s, %s, %s, %s, %s, %s) RETURNING id """, ( caption.brochure_node_id, caption.car_part, caption.tone.value, caption.duration_class.value, caption.text, caption.is_brand_compliant, caption.compliance_notes, )) result = cur.fetchone() conn.commit() return result[0] if result else None except Exception as e: logger.error(f"Error inserting caption: {e}") conn.rollback() return None def insert_voiceover(self, voiceover) -> Optional[UUID]: """Insert a voiceover line.""" conn = self._get_connection() if not conn: return None try: with conn.cursor() as cur: cur.execute(""" INSERT INTO voiceover_library (brochure_node_id, car_part, tone, duration_class, text, estimated_duration_ms, is_brand_compliant, compliance_notes) VALUES (%s, %s, %s, %s, %s, %s, %s, %s) RETURNING id """, ( voiceover.brochure_node_id, voiceover.car_part, voiceover.tone.value, voiceover.duration_class.value, voiceover.text, voiceover.estimated_duration_ms, voiceover.is_brand_compliant, voiceover.compliance_notes, )) result = cur.fetchone() conn.commit() return result[0] if result else None except Exception as e: logger.error(f"Error inserting voiceover: {e}") conn.rollback() return None def query_captions( self, car_part: Optional[str] = None, tone=None, duration_class=None, limit: int = 10, ) -> List[Any]: """Query caption library with filters.""" conn = self._get_connection() if not conn: return [] try: with conn.cursor() as cur: query = "SELECT * FROM captions_library WHERE 1=1" params = [] if car_part: query += " AND car_part = %s" params.append(car_part) if tone: query += " AND tone = %s" params.append(tone.value if hasattr(tone, 'value') else tone) if duration_class: query += " AND duration_class = %s" params.append(duration_class.value if hasattr(duration_class, 'value') else duration_class) query += " ORDER BY usage_count ASC LIMIT %s" params.append(limit) cur.execute(query, params) # Return raw dicts (would map to CaptionVariant in production) columns = [desc[0] for desc in cur.description] return [dict(zip(columns, row)) for row in cur.fetchall()] except Exception as e: logger.error(f"Error querying captions: {e}") return [] def query_voiceovers( self, car_part: Optional[str] = None, tone=None, duration_class=None, limit: int = 10, ) -> List[Any]: """Query voiceover library with filters.""" conn = self._get_connection() if not conn: return [] try: with conn.cursor() as cur: query = "SELECT * FROM voiceover_library WHERE 1=1" params = [] if car_part: query += " AND car_part = %s" params.append(car_part) if tone: query += " AND tone = %s" params.append(tone.value if hasattr(tone, 'value') else tone) if duration_class: query += " AND duration_class = %s" params.append(duration_class.value if hasattr(duration_class, 'value') else duration_class) query += " ORDER BY usage_count ASC LIMIT %s" params.append(limit) cur.execute(query, params) columns = [desc[0] for desc in cur.description] return [dict(zip(columns, row)) for row in cur.fetchall()] except Exception as e: logger.error(f"Error querying voiceovers: {e}") return [] def get_mapping_stats(self) -> Dict[str, int]: """Get statistics on brochure-asset mapping approval status.""" conn = self._get_connection() if not conn: return {} try: with conn.cursor() as cur: cur.execute(""" SELECT CASE WHEN is_approved = true THEN 'approved' WHEN is_approved = false THEN 'rejected' ELSE 'pending' END as status, COUNT(*) as count FROM brochure_asset_map GROUP BY is_approved """) return {row[0]: row[1] for row in cur.fetchall()} except Exception as e: logger.error(f"Error getting mapping stats: {e}") return {}