File size: 4,173 Bytes
04705fd
 
 
 
 
 
 
 
 
4bc8241
 
04705fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import shutil
import cv2
import torch
import numpy as np
import zarr
from PIL import Image
from typing import Tuple, List

from utils.config import config, get_logger
from utils.models import device, clip_processor, clip_model, collection, chroma_client, vlm_model, vlm_tokenizer
# Module-level logger shared by the indexing and question-answering functions below.
logger = get_logger("Engine")

def process_and_index_video(video_path: str) -> Tuple[str, List[Image.Image]]:
    """Sample frames from a video, cache them on disk, and index them by CLIP embedding.

    Frames are sampled roughly ``config.default_fps`` times per second of video,
    converted to RGB and appended to a Zarr v3 array stored at
    ``config.cache_dir``. Each cached frame is embedded with the CLIP image
    encoder, L2-normalized, and stored in a freshly re-created ChromaDB
    collection together with its timestamp and its index into the cache.

    Args:
        video_path: Path of the video file to process. A falsy value is
            treated as "nothing uploaded".

    Returns:
        A ``(status_message, sample_frames)`` tuple. ``sample_frames`` holds up
        to the first three cached frames as PIL images, or is empty when the
        video was missing or unreadable.
    """
    # Rebinds the module-level collection so ask_video_question sees the new index.
    global collection

    if not video_path:
        return "Please upload a video.", []

    logger.info("Starting fast extraction process...")
    vidcap = cv2.VideoCapture(video_path)
    try:
        video_fps = vidcap.get(cv2.CAP_PROP_FPS)
        success, frame = vidcap.read()
        if not success:
            # Bail out BEFORE touching the old cache: the existing index stays usable.
            return "Failed to read video.", []

        # Some containers/backends report 0 or NaN FPS; dividing by that would
        # crash or yield NaN timestamps, so fall back to the sampling rate.
        if not np.isfinite(video_fps) or video_fps <= 0:
            logger.warning(
                f"Invalid FPS ({video_fps}) reported for video; assuming {config.default_fps}."
            )
            video_fps = float(config.default_fps)
        frame_interval = max(1, int(video_fps / config.default_fps))

        # Only now that the video is known to be readable, drop the previous cache.
        if os.path.exists(config.cache_dir):
            logger.info(f"Clearing old cache at {config.cache_dir}...")
            shutil.rmtree(config.cache_dir, ignore_errors=True)

        # cvtColor BGR->RGB keeps the shape, so the raw frame's shape is enough.
        h, w, c = frame.shape

        logger.info(f"Allocating strict Zarr v3 SSD cache at {config.cache_dir}...")
        frame_cache = zarr.create_array(
            config.cache_dir, shape=(0, h, w, c), chunks=(10, h, w, c), dtype='uint8', zarr_format=3
        )

        timestamps = []
        count = 0
        while success:
            # Keep every frame_interval-th frame; the others are read and discarded.
            if count % frame_interval == 0:
                rgb_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame_cache.append(np.expand_dims(rgb_image, axis=0), axis=0)
                timestamps.append(count / video_fps)

            success, frame = vidcap.read()
            count += 1
    finally:
        # Always free the decoder handle, including on early return or error.
        vidcap.release()

    logger.info("Generating CLIP embeddings in batches...")
    all_embeddings = []
    total_frames = frame_cache.shape[0]

    for i in range(0, total_frames, config.batch_size):
        batch_arrays = frame_cache[i : i + config.batch_size]
        batch_pil = [Image.fromarray(arr) for arr in batch_arrays]
        inputs = clip_processor(images=batch_pil, return_tensors="pt").to(device)

        with torch.no_grad():
            features = clip_model.get_image_features(**inputs)

        # Unit-length vectors, converted to plain lists for ChromaDB.
        normalized = (features / features.norm(p=2, dim=-1, keepdim=True)).cpu().tolist()
        all_embeddings.extend(normalized)

    logger.info("Indexing into ChromaDB...")
    ids = [f"frame_{i}" for i in range(total_frames)]
    metadatas = [{"timestamp": ts, "frame_idx": i} for i, ts in enumerate(timestamps)]

    chroma_client.delete_collection(config.collection_name)
    collection = chroma_client.create_collection(config.collection_name)

    # Insert in bounded slices: ChromaDB rejects a single add() above its
    # maximum batch size, which long videos can exceed.
    add_chunk = 1000
    for start in range(0, total_frames, add_chunk):
        end = start + add_chunk
        collection.add(
            embeddings=all_embeddings[start:end],
            metadatas=metadatas[start:end],
            ids=ids[start:end],
        )

    # One sliced read instead of up to three separate reads from the cache.
    sample_frames = [Image.fromarray(arr) for arr in frame_cache[:3]]
    return f"Processed {total_frames} frames strictly on SSD cache.", sample_frames


def ask_video_question(query: str) -> Tuple[str, List[Image.Image]]:
    """Answer a free-text question about the most recently indexed video.

    The query is embedded with the CLIP text encoder and L2-normalized, then
    used to look up the closest frames (at most three) in the ChromaDB
    collection. Those frames are loaded from the Zarr cache at
    ``config.cache_dir``, and the first retrieved frame is passed to the VLM
    together with the query.

    Args:
        query: The user's question.

    Returns:
        A ``(answer, retrieved_images)`` tuple. When no video is indexed, the
        query is blank, or no frame can be retrieved, ``answer`` is an
        explanatory message and ``retrieved_images`` is empty.
    """
    indexed_frames = collection.count()
    if indexed_frames == 0:
        return "Please process a video first.", []

    if not query or not query.strip():
        return "Please enter a question.", []

    # The index may outlive the on-disk frames (e.g. cache removed externally);
    # report that instead of crashing when opening the array.
    if not os.path.exists(config.cache_dir):
        logger.warning(f"Frame cache missing at {config.cache_dir}.")
        return "Please process a video first.", []

    logger.info(f"Processing query: '{query}'")

    inputs = clip_processor(text=[query], return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_features = clip_model.get_text_features(**inputs)
    text_embedding = (text_features / text_features.norm(p=2, dim=-1, keepdim=True)).cpu().tolist()

    # Never request more neighbours than there are indexed frames.
    results = collection.query(
        query_embeddings=text_embedding, n_results=min(3, indexed_frames)
    )

    # Frames are read back from the on-disk cache by the index stored in metadata.
    frame_cache = zarr.open_array(config.cache_dir, mode="r")
    retrieved_images = [
        Image.fromarray(frame_cache[int(metadata['frame_idx'])])
        for metadata in results['metadatas'][0]
    ]
    if not retrieved_images:
        return "No matching frames were found.", []

    logger.info("Generating VLM answer...")
    # Only the top-ranked frame is shown to the VLM.
    encoded_image = vlm_model.encode_image(retrieved_images[0])
    answer = vlm_model.answer_question(encoded_image, query, vlm_tokenizer)

    return answer, retrieved_images