RAG / utils /engine.py
Hanzo03's picture
Update utils/engine.py
4bc8241 verified
raw
history blame
4.17 kB
import os
import shutil
import cv2
import torch
import numpy as np
import zarr
from PIL import Image
from typing import Tuple, List
from utils.config import config, get_logger
from utils.models import device, clip_processor, clip_model, collection, chroma_client, vlm_model, vlm_tokenizer
# Module-level logger for this engine; used by both the video-indexing and the
# question-answering functions in this file.
logger = get_logger("Engine")
def _extract_frames_to_cache(video_path: str):
    """Decode ``video_path`` and write sampled RGB frames to a fresh Zarr cache.

    Frames are sampled every ``max(1, int(video_fps / config.default_fps))``
    decoded frames and stored as ``uint8`` arrays of shape ``(N, h, w, c)`` at
    ``config.cache_dir``.

    Returns:
        ``(frame_cache, timestamps)`` where ``timestamps[i]`` is the time in
        seconds of cached frame ``i``; or ``(None, [])`` if the first frame
        cannot be read.
    """
    chunk_len = 10  # frames per Zarr chunk along the time axis
    fallback_fps = 30.0  # assumed rate when the container reports none

    vidcap = cv2.VideoCapture(video_path)
    try:
        success, frame = vidcap.read()
        if not success:
            # Leave the existing cache untouched: the current index still
            # points at it, and a bad upload should not break later queries.
            return None, []

        video_fps = vidcap.get(cv2.CAP_PROP_FPS)
        if not (video_fps > 0):  # also catches 0.0 and NaN
            logger.warning(
                f"Video reports an invalid FPS ({video_fps}); assuming {fallback_fps}."
            )
            video_fps = fallback_fps
        frame_interval = max(1, int(video_fps / config.default_fps))

        # BGR->RGB conversion keeps the shape, so the raw frame's shape is the
        # cached frame's shape.
        h, w, c = frame.shape

        # Strict cache cleanup, only once the new video is known to be readable.
        if os.path.exists(config.cache_dir):
            logger.info(f"Clearing old cache at {config.cache_dir}...")
            shutil.rmtree(config.cache_dir, ignore_errors=True)

        # 🚨 STRICT SSD ALLOCATION
        logger.info(f"Allocating strict Zarr v3 SSD cache at {config.cache_dir}...")
        frame_cache = zarr.create_array(
            config.cache_dir,
            shape=(0, h, w, c),
            chunks=(chunk_len, h, w, c),
            dtype="uint8",
            zarr_format=3,
            # rmtree above ignores errors; make sure leftovers cannot block creation.
            overwrite=True,
        )

        timestamps: List[float] = []
        pending: List[np.ndarray] = []
        count = 0
        while success:
            # 🚀 SPEED OPTIMIZATION: only convert the frames that are kept.
            if count % frame_interval == 0:
                pending.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                timestamps.append(count / video_fps)
                # Write whole chunks at once instead of resizing and
                # rewriting a partial chunk for every single frame.
                if len(pending) == chunk_len:
                    frame_cache.append(np.stack(pending), axis=0)
                    pending.clear()
            success, frame = vidcap.read()
            count += 1
        if pending:
            frame_cache.append(np.stack(pending), axis=0)
        return frame_cache, timestamps
    finally:
        # Release the decoder on every path, including early returns and errors.
        vidcap.release()


def _embed_frames(frame_cache) -> List[List[float]]:
    """Return L2-normalised CLIP image embeddings for every cached frame, in order.

    Frames are read from ``frame_cache`` in batches of ``config.batch_size``.
    """
    embeddings: List[List[float]] = []
    total_frames = frame_cache.shape[0]
    for start in range(0, total_frames, config.batch_size):
        batch_pil = [
            Image.fromarray(arr)
            for arr in frame_cache[start : start + config.batch_size]
        ]
        inputs = clip_processor(images=batch_pil, return_tensors="pt").to(device)
        with torch.no_grad():
            features = clip_model.get_image_features(**inputs)
        normalized = features / features.norm(p=2, dim=-1, keepdim=True)
        embeddings.extend(normalized.cpu().tolist())
    return embeddings


def process_and_index_video(video_path: str) -> Tuple[str, List[Image.Image]]:
    """Sample frames from a video, cache them on disk, and index their CLIP embeddings.

    Sampled RGB frames are stored in a Zarr v3 array at ``config.cache_dir``,
    embedded with the CLIP image encoder, and added to a freshly recreated
    ChromaDB collection named ``config.collection_name``. The collection's
    ids are ``frame_<i>`` and its metadata holds ``timestamp`` (seconds) and
    ``frame_idx`` (row in the Zarr cache).

    Args:
        video_path: Path to the video file; falsy values are rejected.

    Returns:
        ``(status_message, sample_frames)`` where ``sample_frames`` holds up to
        the first three cached frames as PIL images (empty on failure).
    """
    # The module-level collection is rebound below so later queries use it.
    global collection

    if not video_path:
        return "Please upload a video.", []

    logger.info("Starting fast extraction process...")
    frame_cache, timestamps = _extract_frames_to_cache(video_path)
    if frame_cache is None:
        return "Failed to read video.", []
    total_frames = frame_cache.shape[0]

    logger.info("Generating CLIP embeddings in batches...")
    all_embeddings = _embed_frames(frame_cache)

    logger.info("Indexing into ChromaDB...")
    ids = [f"frame_{i}" for i in range(total_frames)]
    metadatas = [{"timestamp": ts, "frame_idx": i} for i, ts in enumerate(timestamps)]
    # Drop and recreate so no embeddings from a previous video survive.
    chroma_client.delete_collection(config.collection_name)
    collection = chroma_client.create_collection(config.collection_name)
    collection.add(embeddings=all_embeddings, metadatas=metadatas, ids=ids)

    sample_frames = [Image.fromarray(frame_cache[i]) for i in range(min(3, total_frames))]
    return f"Processed {total_frames} frames strictly on SSD cache.", sample_frames
def ask_video_question(query: str) -> Tuple[str, List[Image.Image]]:
    """Answer ``query`` about the most recently indexed video.

    The query is embedded with the CLIP text encoder and L2-normalised. The
    nearest frames (at most three) are retrieved from the ChromaDB collection
    and loaded from the Zarr frame cache at ``config.cache_dir``. The first
    retrieved frame is passed to the VLM together with the query.

    Args:
        query: Natural-language question about the video.

    Returns:
        ``(answer, retrieved_images)``. ``retrieved_images`` follows the order
        of the ChromaDB results (presumably most similar first). On invalid
        state or input, a user-facing message and an empty list are returned.
    """
    indexed = collection.count()
    if indexed == 0:
        return "Please process a video first.", []
    if not query or not query.strip():
        return "Please enter a question.", []

    logger.info(f"Processing query: '{query}'")
    # CLIP's text encoder has a fixed maximum token length; truncate long
    # questions instead of failing inside the model.
    inputs = clip_processor(
        text=[query], return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():
        text_features = clip_model.get_text_features(**inputs)
    text_embedding = (
        text_features / text_features.norm(p=2, dim=-1, keepdim=True)
    ).cpu().tolist()

    # Never request more neighbours than the collection holds.
    results = collection.query(query_embeddings=text_embedding, n_results=min(3, indexed))

    # Read strictly from SSD
    frame_cache = zarr.open_array(config.cache_dir, mode="r")
    retrieved_images = [
        Image.fromarray(frame_cache[int(metadata["frame_idx"])])
        for metadata in results["metadatas"][0]
    ]
    if not retrieved_images:
        return "No matching frames were found.", []

    logger.info("Generating VLM answer...")
    encoded_image = vlm_model.encode_image(retrieved_images[0])
    answer = vlm_model.answer_question(encoded_image, query, vlm_tokenizer)
    return answer, retrieved_images