RAG / utils /engine.py
Hanzo03's picture
Update utils/engine.py
b28f276 verified
raw
history blame
4.33 kB
import os
import shutil
import cv2
import torch
import numpy as np
import zarr
from PIL import Image
from typing import Tuple, List
from utils.config import config, get_logger
from utils.models import device, clip_processor, clip_model, collection, chroma_client, vlm_model, vlm_tokenizer
# Module-level logger used by the indexing and query functions in this file.
logger = get_logger("Engine")
# Frame rate assumed when OpenCV cannot report one (CAP_PROP_FPS returns 0/NaN).
_FALLBACK_FPS = 30.0
# Frames per Zarr chunk along the time axis; appends are buffered to this size
# so each chunk is written once instead of being rewritten for every frame.
_CHUNK_FRAMES = 10


def _sample_frames_into_cache(
    vidcap: "cv2.VideoCapture", first_frame: np.ndarray, video_fps: float
) -> Tuple["zarr.Array", List[float]]:
    """Copy every ``frame_interval``-th frame of ``vidcap`` into a new Zarr array.

    ``first_frame`` is the frame already read from ``vidcap`` (BGR, as OpenCV
    returns it). Sampled frames are converted to RGB and stored as uint8 at
    ``config.cache_dir``.

    Returns the Zarr array and the timestamp in seconds of each stored frame.
    """
    frame_interval = max(1, int(video_fps / config.default_fps))
    h, w, c = first_frame.shape

    logger.info(f"Allocating strict Zarr v3 SSD cache at {config.cache_dir}...")
    frame_cache = zarr.create_array(
        config.cache_dir,
        shape=(0, h, w, c),
        chunks=(_CHUNK_FRAMES, h, w, c),
        dtype="uint8",
        zarr_format=3,
    )

    timestamps: List[float] = []
    pending: List[np.ndarray] = []
    frame, count, success = first_frame, 0, True
    while success:
        if count % frame_interval == 0:
            pending.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            timestamps.append(count / video_fps)
            if len(pending) == _CHUNK_FRAMES:
                frame_cache.append(np.stack(pending), axis=0)
                pending.clear()
        success, frame = vidcap.read()
        count += 1
    # Flush the final, possibly partial, chunk.
    if pending:
        frame_cache.append(np.stack(pending), axis=0)
    return frame_cache, timestamps


def _embed_cached_frames(frame_cache: "zarr.Array") -> List[List[float]]:
    """Return L2-normalised CLIP image embeddings for every frame in ``frame_cache``.

    Frames are processed in batches of ``config.batch_size``.
    """
    embeddings: List[List[float]] = []
    total_frames = frame_cache.shape[0]
    for start in range(0, total_frames, config.batch_size):
        batch_arrays = frame_cache[start : start + config.batch_size]
        batch_pil = [Image.fromarray(arr) for arr in batch_arrays]
        inputs = clip_processor(images=batch_pil, return_tensors="pt").to(device)
        with torch.no_grad():
            # Run the vision tower and its projection explicitly so the result
            # is a plain projected tensor comparable with the text embeddings.
            vision_outputs = clip_model.vision_model(**inputs)
            features = clip_model.visual_projection(vision_outputs.pooler_output)
            normalized = features / features.norm(p=2, dim=-1, keepdim=True)
        embeddings.extend(normalized.cpu().tolist())
    return embeddings


def _rebuild_index(embeddings: List[List[float]], timestamps: List[float]) -> None:
    """Drop and recreate the Chroma collection, then add one entry per frame.

    Rebinds this module's ``collection`` global so ``ask_video_question``
    queries the new collection.
    """
    global collection
    ids = [f"frame_{i}" for i in range(len(embeddings))]
    metadatas = [{"timestamp": ts, "frame_idx": i} for i, ts in enumerate(timestamps)]
    chroma_client.delete_collection(config.collection_name)
    collection = chroma_client.create_collection(config.collection_name)
    collection.add(embeddings=embeddings, metadatas=metadatas, ids=ids)


def process_and_index_video(video_path: str) -> Tuple[str, List[Image.Image]]:
    """Extract frames from ``video_path``, cache them on disk and index them.

    Frames are sampled at roughly ``config.default_fps`` frames per second and
    stored as RGB uint8 in a Zarr v3 array at ``config.cache_dir``. They are
    embedded with CLIP and written to a freshly recreated Chroma collection.
    Each entry's metadata holds the frame's timestamp and its cache index.

    Args:
        video_path: Path to the video file. Falsy values are rejected.

    Returns:
        A status message and up to three sample frames. On failure the
        message describes the problem and the list is empty.
    """
    if not video_path:
        return "Please upload a video.", []

    vidcap = cv2.VideoCapture(video_path)
    try:
        success, first_frame = vidcap.read()
        if not success:
            # Bail out before touching the cache. The previous video's frames
            # stay readable for the index that still references them.
            return "Failed to read video.", []

        video_fps = vidcap.get(cv2.CAP_PROP_FPS)
        # `not x > 0` also catches NaN; a zero FPS would divide by zero below.
        if not video_fps > 0:
            logger.warning(
                f"Video FPS unavailable ({video_fps}); assuming {_FALLBACK_FPS}."
            )
            video_fps = _FALLBACK_FPS

        if os.path.exists(config.cache_dir):
            logger.info(f"Clearing old cache at {config.cache_dir}...")
            shutil.rmtree(config.cache_dir, ignore_errors=True)
        logger.info("Starting fast extraction process...")
        frame_cache, timestamps = _sample_frames_into_cache(vidcap, first_frame, video_fps)
    finally:
        # Release the capture on every path, including the early return above.
        vidcap.release()

    logger.info("Generating CLIP embeddings in batches...")
    all_embeddings = _embed_cached_frames(frame_cache)

    logger.info("Indexing into ChromaDB...")
    _rebuild_index(all_embeddings, timestamps)

    total_frames = frame_cache.shape[0]
    sample_frames = [Image.fromarray(frame_cache[i]) for i in range(min(3, total_frames))]
    return f"Processed {total_frames} frames strictly on SSD cache.", sample_frames
def ask_video_question(query: str) -> Tuple[str, List[Image.Image]]:
    """Answer ``query`` using the frames of the most recently indexed video.

    The query is embedded with CLIP's text tower. Up to three nearest frames
    are fetched from the Chroma collection and loaded from the Zarr cache at
    ``config.cache_dir``. The best match is passed to the VLM.

    Args:
        query: Natural-language question about the video.

    Returns:
        The VLM's answer and the retrieved frames (best match first). Returns a
        user-facing message and an empty list if no video is indexed, the
        query is blank, or no frames match.
    """
    indexed = collection.count()
    if indexed == 0:
        return "Please process a video first.", []
    if not (query and query.strip()):
        return "Please enter a question.", []

    logger.info(f"Processing query: '{query}'")
    # CLIP's text encoder accepts a fixed maximum number of tokens (77 for the
    # standard checkpoints). Truncate long queries instead of failing on them.
    inputs = clip_processor(
        text=[query], return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():
        # Run the text tower and its projection explicitly so the result is a
        # plain projected tensor comparable with the indexed image embeddings.
        text_outputs = clip_model.text_model(**inputs)
        text_features = clip_model.text_projection(text_outputs.pooler_output)
        text_embedding = (
            text_features / text_features.norm(p=2, dim=-1, keepdim=True)
        ).cpu().tolist()

    # Never request more neighbours than the collection contains.
    results = collection.query(query_embeddings=text_embedding, n_results=min(3, indexed))
    matches = results["metadatas"][0] if results.get("metadatas") else []
    if not matches:
        return "No matching frames were found.", []

    frame_cache = zarr.open_array(config.cache_dir, mode="r")
    retrieved_images = [
        Image.fromarray(frame_cache[int(metadata["frame_idx"])]) for metadata in matches
    ]

    logger.info("Generating VLM answer...")
    encoded_image = vlm_model.encode_image(retrieved_images[0])
    answer = vlm_model.answer_question(encoded_image, query, vlm_tokenizer)
    return answer, retrieved_images