# ai-reel-creator-platform / src/utils/embedding_store.py
# Uploaded by acd23 with huggingface_hub (commit 226bc4e, verified).
"""
EmbeddingStore — Unified interface for vector storage and retrieval.
Wraps PostgreSQL + pgvector for CLIP embeddings and asset/brochure queries.
"""
import logging
from typing import List, Dict, Optional, Any
from uuid import UUID
import json
import numpy as np
logger = logging.getLogger(__name__)


class EmbeddingStore:
    """
    PostgreSQL + pgvector embedding store for the Reel Creator Platform.

    Provides:
    - Asset metadata CRUD with vector search
    - Brochure node storage with CLIP embeddings
    - Brochure-to-asset mapping storage
    - Caption and voiceover library queries
    - Semantic search via cosine similarity

    Database methods degrade gracefully. If no connection can be opened, they
    return ``None`` / ``False`` / an empty container. If a statement fails,
    they log the error and roll the transaction back, so the shared connection
    stays usable for later calls. They then return the same fallback value.
    """

    def __init__(
        self,
        db_connection_string: Optional[str] = None,
        pool_size: int = 10,
    ):
        """
        Args:
            db_connection_string: Connection string passed to
                ``psycopg2.connect``. When falsy, a local default connection
                (localhost:5432, dbname ``reel_creator``, user ``reel_user``)
                is used instead.
            pool_size: Not used by the current implementation, which keeps a
                single connection. Accepted for interface compatibility.
        """
        self.connection_string = db_connection_string
        self.pool_size = pool_size
        # A single psycopg2 connection, opened lazily (the name is historical).
        self._pool = None
        # CLIP model, processor and device, loaded lazily on first embed call.
        self._clip_model = None
        self._clip_processor = None
        self._clip_device = None

    # ============================================================
    # CONNECTION & MODEL LOADING
    # ============================================================
    def _get_connection(self):
        """
        Return the lazily-opened database connection, or ``None``.

        Despite the attribute name ``_pool``, this is one psycopg2 connection
        shared by every method. A connection that psycopg2 reports as closed
        is discarded and re-opened. Connection failures are logged as a
        warning and yield ``None``, and the next call tries again.

        Raises:
            ImportError: if ``psycopg2`` is not installed.
        """
        if self._pool is not None and getattr(self._pool, "closed", 0):
            # psycopg2 sets ``closed`` non-zero once the connection is closed
            # or broken; drop it so a fresh one is opened below.
            self._pool = None
        if self._pool is None:
            try:
                import psycopg2
            except ImportError as err:
                raise ImportError(
                    "PostgreSQL support requires 'psycopg2-binary'. "
                    "Install: pip install psycopg2-binary"
                ) from err
            try:
                if self.connection_string:
                    self._pool = psycopg2.connect(self.connection_string)
                else:
                    # Default local connection
                    self._pool = psycopg2.connect(
                        host="localhost",
                        port=5432,
                        dbname="reel_creator",
                        user="reel_user",
                    )
            except Exception as e:
                logger.warning("Database connection not available: %s", e)
                self._pool = None
        return self._pool

    def _get_clip_model(self):
        """
        Lazily load and cache the CLIP model and processor.

        Loads ``openai/clip-vit-large-patch14`` onto CUDA when available,
        otherwise onto CPU.

        Returns:
            ``(model, processor)``.

        Raises:
            Whatever ``torch`` / ``transformers`` raise while loading; the
            error is logged before being re-raised.
        """
        if self._clip_model is None:
            try:
                import torch
                from transformers import CLIPModel, CLIPProcessor

                model_name = "openai/clip-vit-large-patch14"
                device = "cuda" if torch.cuda.is_available() else "cpu"
                logger.info("Loading CLIP model: %s on %s", model_name, device)
                model = CLIPModel.from_pretrained(model_name).to(device)
                processor = CLIPProcessor.from_pretrained(model_name)
            except Exception as e:
                logger.error("Failed to load CLIP model: %s", e)
                raise
            # Publish only after everything loaded, so a failed processor load
            # cannot leave a half-initialised (model, None) pair cached.
            self._clip_model = model
            self._clip_processor = processor
            self._clip_device = device
        return self._clip_model, self._clip_processor

    def embed_text(self, text: str) -> List[float]:
        """
        Return the L2-normalised CLIP text embedding of *text* as a list.

        Text is tokenised with truncation at ``max_length=77`` tokens.
        """
        import torch

        model, processor = self._get_clip_model()
        inputs = processor(
            text=[text],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77,
        ).to(self._clip_device)
        with torch.no_grad():
            text_features = model.get_text_features(**inputs)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features.cpu().numpy()[0].tolist()

    def embed_image(self, image_path: str) -> List[float]:
        """Return the L2-normalised CLIP image embedding of the file at *image_path*."""
        import torch
        from PIL import Image

        model, processor = self._get_clip_model()
        # The context manager closes the file handle promptly. ``convert``
        # returns an independent, fully-loaded image, so it stays valid
        # after the source image is closed.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")
        inputs = processor(
            images=image,
            return_tensors="pt",
        ).to(self._clip_device)
        with torch.no_grad():
            image_features = model.get_image_features(**inputs)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features.cpu().numpy()[0].tolist()

    # ============================================================
    # INTERNAL SQL HELPERS
    # ============================================================
    @staticmethod
    def _as_vector_param(embedding: Any) -> Any:
        """
        Normalise an embedding into a value psycopg2 can bind for a pgvector column.

        - ``None`` or an empty sequence/array becomes ``None`` (SQL NULL).
        - A string is passed through unchanged (pgvector text form, e.g. ``'[1,2,3]'``).
        - A list, tuple or NumPy array becomes a plain ``list`` of Python floats.
          psycopg2 cannot adapt ``numpy.ndarray`` itself.
        """
        if embedding is None:
            return None
        if isinstance(embedding, str):
            return embedding
        values = np.asarray(embedding, dtype=float)
        if values.size == 0:
            return None
        return values.tolist()

    @staticmethod
    def _enum_value(value: Any) -> Any:
        """Return ``value.value`` for Enum-like objects, else ``value`` unchanged."""
        return value.value if hasattr(value, "value") else value

    @staticmethod
    def _safe_rollback(conn) -> None:
        """Roll back *conn*'s current transaction, logging rather than raising on failure."""
        try:
            conn.rollback()
        except Exception as e:  # e.g. the connection itself has gone away
            logger.error("Rollback failed: %s", e)

    @staticmethod
    def _execute_and_commit(conn, sql: str, params, fetch_one: bool = False):
        """
        Execute one write statement on *conn* and commit it.

        Returns:
            The first result row when ``fetch_one`` is true (``None`` if the
            statement produced no row); otherwise ``None``. Exceptions
            propagate so the caller can log and roll back.
        """
        with conn.cursor() as cur:
            cur.execute(sql, params)
            row = cur.fetchone() if fetch_one else None
        conn.commit()
        return row

    @staticmethod
    def _fetch_dicts(conn, sql: str, params=None) -> List[Dict[str, Any]]:
        """
        Run a SELECT on *conn* and return every row as a ``{column: value}`` dict.

        The read transaction is committed afterwards. psycopg2 implicitly
        opens a transaction for the SELECT, and leaving it open keeps the
        session "idle in transaction". Exceptions propagate so the caller can
        log and roll back. Rolling back matters: a failed statement otherwise
        makes every later statement on this connection fail.
        """
        with conn.cursor() as cur:
            cur.execute(sql, params)
            columns = [desc[0] for desc in cur.description]
            rows = [dict(zip(columns, row)) for row in cur.fetchall()]
        conn.commit()
        return rows

    # ============================================================
    # ASSET OPERATIONS
    # ============================================================
    def insert_asset(self, asset_data: Dict[str, Any]) -> Optional[UUID]:
        """
        Insert a row into ``assets`` and return its generated ``id``.

        ``asset_data`` must contain ``file_path``, ``file_name`` and
        ``asset_type``. ``source``, ``resolution``, ``duration_ms``,
        ``frame_rate`` and ``file_size_bytes`` are optional (NULL if absent).

        Returns:
            The new id as returned by the driver, or ``None`` in three cases:
            there is no connection, the insert failed, or a required key is
            missing.
        """
        conn = self._get_connection()
        if not conn:
            logger.warning("No DB connection, skipping insert")
            return None
        try:
            row = self._execute_and_commit(
                conn,
                """
                INSERT INTO assets (file_path, file_name, asset_type, source,
                    resolution, duration_ms, frame_rate, file_size_bytes)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                RETURNING id
                """,
                (
                    asset_data["file_path"],
                    asset_data["file_name"],
                    asset_data["asset_type"],
                    asset_data.get("source"),
                    asset_data.get("resolution"),
                    asset_data.get("duration_ms"),
                    asset_data.get("frame_rate"),
                    asset_data.get("file_size_bytes"),
                ),
                fetch_one=True,
            )
        except Exception as e:
            logger.error("Error inserting asset: %s", e)
            self._safe_rollback(conn)
            return None
        return row[0] if row else None

    def insert_asset_metadata(
        self,
        asset_id: UUID,
        metadata: Dict[str, Any],
        embedding: Optional[List[float]] = None,
    ) -> bool:
        """
        Upsert the ``asset_metadata`` row for *asset_id*.

        On conflict with an existing ``asset_id``, every metadata column and
        the embedding are overwritten. ``confidence_score`` defaults to 1.0
        and ``review_flag`` to False. *embedding* may be a list or NumPy
        array; ``None`` or an empty one stores NULL.

        Returns:
            True on success; False if there is no connection or the upsert failed.
        """
        conn = self._get_connection()
        if not conn:
            return False
        try:
            self._execute_and_commit(
                conn,
                """
                INSERT INTO asset_metadata
                    (asset_id, description, shot_type, camera_angle, subject_part,
                     mood, dominant_colours, confidence_score, review_flag, embedding_768)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT (asset_id) DO UPDATE SET
                    description = EXCLUDED.description,
                    shot_type = EXCLUDED.shot_type,
                    camera_angle = EXCLUDED.camera_angle,
                    subject_part = EXCLUDED.subject_part,
                    mood = EXCLUDED.mood,
                    dominant_colours = EXCLUDED.dominant_colours,
                    confidence_score = EXCLUDED.confidence_score,
                    review_flag = EXCLUDED.review_flag,
                    embedding_768 = EXCLUDED.embedding_768
                """,
                (
                    asset_id,
                    metadata.get("description"),
                    metadata.get("shot_type"),
                    metadata.get("camera_angle"),
                    metadata.get("subject_part"),
                    metadata.get("mood"),
                    metadata.get("dominant_colours"),
                    metadata.get("confidence_score", 1.0),
                    metadata.get("review_flag", False),
                    self._as_vector_param(embedding),
                ),
            )
        except Exception as e:
            logger.error("Error inserting asset metadata: %s", e)
            self._safe_rollback(conn)
            return False
        return True

    def search_assets(
        self,
        query_embedding: List[float],
        limit: int = 10,
        asset_type: Optional[str] = None,
        subject_part: Optional[str] = None,
        mood: Optional[str] = None,
        min_confidence: float = 0.5,
    ) -> List[Dict[str, Any]]:
        """
        Semantic search over assets by cosine similarity to *query_embedding*.

        Only assets with a stored embedding and ``confidence_score >=
        min_confidence`` are considered. The optional filters are applied as
        exact matches when truthy. Results are nearest-first, with
        ``similarity = 1 - cosine_distance``.

        Returns:
            Up to *limit* row dicts. Returns ``[]`` if there is no connection,
            the query embedding is empty, or the query failed.
        """
        conn = self._get_connection()
        if not conn:
            logger.warning("No DB connection, returning empty results")
            return []
        try:
            vector = self._as_vector_param(query_embedding)
            if vector is None:
                # Ordering by NULL distance would return arbitrary rows.
                logger.warning("Empty query embedding, returning empty results")
                return []
            query = """
                SELECT
                    a.id as asset_id,
                    a.file_path,
                    a.asset_type,
                    am.description,
                    am.shot_type,
                    am.camera_angle,
                    am.subject_part,
                    am.mood,
                    1 - (am.embedding_768 <=> %s::vector) as similarity
                FROM assets a
                JOIN asset_metadata am ON a.id = am.asset_id
                WHERE am.embedding_768 IS NOT NULL
                AND am.confidence_score >= %s
            """
            params: List[Any] = [vector, min_confidence]
            # Clauses are hard-coded literals; only the values are bound.
            for clause, value in (
                ("a.asset_type = %s", asset_type),
                ("am.subject_part = %s", subject_part),
                ("am.mood = %s", mood),
            ):
                if value:
                    query += f" AND {clause}"
                    params.append(value)
            query += " ORDER BY am.embedding_768 <=> %s::vector LIMIT %s"
            params.extend([vector, limit])
            return self._fetch_dicts(conn, query, params)
        except Exception as e:
            logger.error("Error searching assets: %s", e)
            self._safe_rollback(conn)
            return []

    # ============================================================
    # VIDEO EVENT OPERATIONS
    # ============================================================
    def insert_video_event(self, event_data: Dict[str, Any]) -> Optional[UUID]:
        """
        Insert a ``video_events`` segment and return its generated ``id``.

        ``asset_id``, ``start_ms`` and ``end_ms`` are required. Every other
        field is optional, and ``confidence_score`` defaults to 1.0.
        ``embedding_768`` may be a list or NumPy array.

        Returns:
            The new id, or ``None`` if there is no connection or the insert failed.
        """
        conn = self._get_connection()
        if not conn:
            return None
        try:
            row = self._execute_and_commit(
                conn,
                """
                INSERT INTO video_events
                    (asset_id, start_ms, end_ms, description, shot_type,
                     camera_angle, subject_part, mood, embedding_768, confidence_score, keyframe_path)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                RETURNING id
                """,
                (
                    event_data["asset_id"],
                    event_data["start_ms"],
                    event_data["end_ms"],
                    event_data.get("description"),
                    event_data.get("shot_type"),
                    event_data.get("camera_angle"),
                    event_data.get("subject_part"),
                    event_data.get("mood"),
                    self._as_vector_param(event_data.get("embedding_768")),
                    event_data.get("confidence_score", 1.0),
                    event_data.get("keyframe_path"),
                ),
                fetch_one=True,
            )
        except Exception as e:
            logger.error("Error inserting video event: %s", e)
            self._safe_rollback(conn)
            return None
        return row[0] if row else None

    def search_video_events(
        self,
        query_embedding: List[float],
        limit: int = 10,
        subject_part: Optional[str] = None,
        mood: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """
        Semantic search over video events by cosine similarity.

        Optional filters are exact matches when truthy. Results are
        nearest-first, with ``similarity = 1 - cosine_distance``.

        Returns:
            Up to *limit* row dicts. Returns ``[]`` if there is no connection,
            the query embedding is empty, or the query failed.
        """
        conn = self._get_connection()
        if not conn:
            return []
        try:
            vector = self._as_vector_param(query_embedding)
            if vector is None:
                logger.warning("Empty query embedding, returning empty results")
                return []
            query = """
                SELECT
                    ve.id as event_id,
                    ve.asset_id,
                    a.file_path,
                    ve.start_ms,
                    ve.end_ms,
                    ve.duration_ms,
                    ve.description,
                    ve.subject_part,
                    ve.mood,
                    1 - (ve.embedding_768 <=> %s::vector) as similarity
                FROM video_events ve
                JOIN assets a ON ve.asset_id = a.id
                WHERE ve.embedding_768 IS NOT NULL
            """
            params: List[Any] = [vector]
            for clause, value in (
                ("ve.subject_part = %s", subject_part),
                ("ve.mood = %s", mood),
            ):
                if value:
                    query += f" AND {clause}"
                    params.append(value)
            query += " ORDER BY ve.embedding_768 <=> %s::vector LIMIT %s"
            params.extend([vector, limit])
            return self._fetch_dicts(conn, query, params)
        except Exception as e:
            logger.error("Error searching video events: %s", e)
            self._safe_rollback(conn)
            return []

    # ============================================================
    # BROCHURE NODE OPERATIONS
    # ============================================================
    def insert_brochure_node(self, node) -> Optional[UUID]:
        """
        Insert a brochure node (with optional embedding) and return its ``id``.

        *node* is read by attribute. ``section`` and each ``tone_tags`` entry
        must expose ``.value``. ``spec_highlights`` is stored as JSON text,
        and ``embedding_768`` may be a list or NumPy array.

        Returns:
            The new id, or ``None`` if there is no connection or the insert failed.
        """
        conn = self._get_connection()
        if not conn:
            return None
        try:
            row = self._execute_and_commit(
                conn,
                """
                INSERT INTO brochure_nodes
                    (section, title, content, key_features, taglines,
                     spec_highlights, car_part_referenced, tone_tags,
                     embedding_768, page_number, source_pdf)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                RETURNING id
                """,
                (
                    node.section.value,
                    node.title,
                    node.content,
                    node.key_features,
                    node.taglines,
                    json.dumps(node.spec_highlights) if node.spec_highlights else None,
                    node.car_part_referenced,
                    [t.value for t in node.tone_tags] if node.tone_tags else None,
                    self._as_vector_param(node.embedding_768),
                    node.page_number,
                    node.source_pdf,
                ),
                fetch_one=True,
            )
        except Exception as e:
            logger.error("Error inserting brochure node: %s", e)
            self._safe_rollback(conn)
            return None
        return row[0] if row else None

    # ============================================================
    # BROCHURE-ASSET MAP OPERATIONS
    # ============================================================
    def insert_brochure_asset_map(self, mapping) -> bool:
        """
        Upsert a brochure-to-asset mapping.

        On conflict on ``(brochure_node_id, asset_id, video_event_id)``, only
        ``similarity_score``, ``confidence_score`` and ``rank`` are
        refreshed, and ``updated_at`` is set to now(). An existing row's
        ``mapping_type``, ``is_approved`` and ``reviewer_notes`` are left
        untouched.

        NOTE(review): if ``video_event_id`` (or ``asset_id``) is NULL, a plain
        unique index treats NULLs as distinct, so the conflict target will not
        match and repeated calls may insert duplicates. Confirm against the
        schema.

        Returns:
            True on success; False if there is no connection or the upsert failed.
        """
        conn = self._get_connection()
        if not conn:
            return False
        try:
            self._execute_and_commit(
                conn,
                """
                INSERT INTO brochure_asset_map
                    (brochure_node_id, asset_id, video_event_id, similarity_score,
                     mapping_type, confidence_score, is_approved, reviewer_notes, rank)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT (brochure_node_id, asset_id, video_event_id) DO UPDATE SET
                    similarity_score = EXCLUDED.similarity_score,
                    confidence_score = EXCLUDED.confidence_score,
                    rank = EXCLUDED.rank,
                    updated_at = now()
                """,
                (
                    mapping.brochure_node_id,
                    mapping.asset_id,
                    mapping.video_event_id,
                    mapping.similarity_score,
                    mapping.mapping_type.value,
                    mapping.confidence_score,
                    mapping.is_approved,
                    mapping.reviewer_notes,
                    mapping.rank,
                ),
            )
        except Exception as e:
            logger.error("Error inserting brochure asset map: %s", e)
            self._safe_rollback(conn)
            return False
        return True

    def get_brochure_mapped_assets(
        self,
        car_part: Optional[str] = None,
        limit: int = 10,
    ) -> List[Dict[str, Any]]:
        """
        Return asset mappings that are approved or still pending review.

        A mapping qualifies when ``is_approved IS NOT FALSE``. When *car_part*
        is given, only nodes whose ``car_part_referenced`` array contains it
        are included. Rows are ordered by ``confidence_score`` descending,
        with NULL scores last (PostgreSQL's default for DESC would put them
        first).

        Returns:
            Up to *limit* row dicts, or ``[]`` if there is no connection or the
            query failed.
        """
        conn = self._get_connection()
        if not conn:
            return []
        try:
            query = """
                SELECT
                    bam.brochure_node_id,
                    bam.asset_id,
                    bam.video_event_id,
                    bam.similarity_score,
                    bam.mapping_type,
                    bam.confidence_score,
                    a.file_path,
                    ve.start_ms,
                    ve.end_ms,
                    bn.car_part_referenced
                FROM brochure_asset_map bam
                JOIN brochure_nodes bn ON bam.brochure_node_id = bn.id
                JOIN assets a ON bam.asset_id = a.id
                LEFT JOIN video_events ve ON bam.video_event_id = ve.id
                WHERE bam.is_approved IS NOT FALSE
            """
            params: List[Any] = []
            if car_part:
                query += " AND %s = ANY(bn.car_part_referenced)"
                params.append(car_part)
            query += " ORDER BY bam.confidence_score DESC NULLS LAST LIMIT %s"
            params.append(limit)
            return self._fetch_dicts(conn, query, params)
        except Exception as e:
            logger.error("Error getting brochure mapped assets: %s", e)
            self._safe_rollback(conn)
            return []

    # ============================================================
    # CAPTION & VOICEOVER OPERATIONS
    # ============================================================
    def insert_caption(self, caption) -> Optional[UUID]:
        """
        Insert a caption variant into ``captions_library`` and return its ``id``.

        ``caption.tone`` and ``caption.duration_class`` must expose ``.value``.

        Returns:
            The new id, or ``None`` if there is no connection or the insert failed.
        """
        conn = self._get_connection()
        if not conn:
            return None
        try:
            row = self._execute_and_commit(
                conn,
                """
                INSERT INTO captions_library
                    (brochure_node_id, car_part, tone, duration_class, text,
                     is_brand_compliant, compliance_notes)
                VALUES (%s, %s, %s, %s, %s, %s, %s)
                RETURNING id
                """,
                (
                    caption.brochure_node_id,
                    caption.car_part,
                    caption.tone.value,
                    caption.duration_class.value,
                    caption.text,
                    caption.is_brand_compliant,
                    caption.compliance_notes,
                ),
                fetch_one=True,
            )
        except Exception as e:
            logger.error("Error inserting caption: %s", e)
            self._safe_rollback(conn)
            return None
        return row[0] if row else None

    def insert_voiceover(self, voiceover) -> Optional[UUID]:
        """
        Insert a voiceover line into ``voiceover_library`` and return its ``id``.

        ``voiceover.tone`` and ``voiceover.duration_class`` must expose ``.value``.

        Returns:
            The new id, or ``None`` if there is no connection or the insert failed.
        """
        conn = self._get_connection()
        if not conn:
            return None
        try:
            row = self._execute_and_commit(
                conn,
                """
                INSERT INTO voiceover_library
                    (brochure_node_id, car_part, tone, duration_class, text,
                     estimated_duration_ms, is_brand_compliant, compliance_notes)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                RETURNING id
                """,
                (
                    voiceover.brochure_node_id,
                    voiceover.car_part,
                    voiceover.tone.value,
                    voiceover.duration_class.value,
                    voiceover.text,
                    voiceover.estimated_duration_ms,
                    voiceover.is_brand_compliant,
                    voiceover.compliance_notes,
                ),
                fetch_one=True,
            )
        except Exception as e:
            logger.error("Error inserting voiceover: %s", e)
            self._safe_rollback(conn)
            return None
        return row[0] if row else None

    def _query_library(
        self,
        table: str,
        error_context: str,
        car_part: Optional[str],
        tone,
        duration_class,
        limit: int,
    ) -> List[Dict[str, Any]]:
        """
        Shared implementation of ``query_captions`` and ``query_voiceovers``.

        ``table`` is interpolated into the SQL, so it must only ever be one of
        the hard-coded table names passed by those two methods, never user
        input. Truthy filters are exact matches. ``tone`` and
        ``duration_class`` may be Enum-like (their ``.value`` is used) or raw
        values. Rows come back least-used first (``usage_count ASC``).
        """
        conn = self._get_connection()
        if not conn:
            return []
        try:
            query = f"SELECT * FROM {table} WHERE 1=1"
            params: List[Any] = []
            if car_part:
                query += " AND car_part = %s"
                params.append(car_part)
            if tone:
                query += " AND tone = %s"
                params.append(self._enum_value(tone))
            if duration_class:
                query += " AND duration_class = %s"
                params.append(self._enum_value(duration_class))
            query += " ORDER BY usage_count ASC LIMIT %s"
            params.append(limit)
            return self._fetch_dicts(conn, query, params)
        except Exception as e:
            logger.error("Error %s: %s", error_context, e)
            self._safe_rollback(conn)
            return []

    def query_captions(
        self,
        car_part: Optional[str] = None,
        tone=None,
        duration_class=None,
        limit: int = 10,
    ) -> List[Any]:
        """
        Query ``captions_library`` with optional filters, least-used first.

        Returns raw row dicts (these would map to ``CaptionVariant`` in
        production), or ``[]`` if there is no connection or the query failed.
        """
        return self._query_library(
            "captions_library", "querying captions", car_part, tone, duration_class, limit
        )

    def query_voiceovers(
        self,
        car_part: Optional[str] = None,
        tone=None,
        duration_class=None,
        limit: int = 10,
    ) -> List[Any]:
        """
        Query ``voiceover_library`` with optional filters, least-used first.

        Returns raw row dicts, or ``[]`` if there is no connection or the
        query failed.
        """
        return self._query_library(
            "voiceover_library", "querying voiceovers", car_part, tone, duration_class, limit
        )

    def get_mapping_stats(self) -> Dict[str, int]:
        """
        Count brochure-asset mappings by approval status.

        Returns:
            A dict mapping ``'approved'`` / ``'rejected'`` / ``'pending'``
            (NULL ``is_approved``) to row counts. A status with no rows is
            absent. Returns ``{}`` if there is no connection or the query
            failed.
        """
        conn = self._get_connection()
        if not conn:
            return {}
        try:
            rows = self._fetch_dicts(
                conn,
                """
                SELECT
                    CASE
                        WHEN is_approved = true THEN 'approved'
                        WHEN is_approved = false THEN 'rejected'
                        ELSE 'pending'
                    END as status,
                    COUNT(*) as count
                FROM brochure_asset_map
                GROUP BY is_approved
                """,
            )
        except Exception as e:
            logger.error("Error getting mapping stats: %s", e)
            self._safe_rollback(conn)
            return {}
        return {row["status"]: row["count"] for row in rows}