File size: 4,931 Bytes
3ca196f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b669dd
 
 
 
 
 
 
3ca196f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b669dd
 
 
 
 
 
 
 
 
 
 
3ca196f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
from transformers import PretrainedConfig
from typing import List, Optional, Dict, Any, Tuple


class VineConfig(PretrainedConfig):
    """
    Configuration class for VINE (Video Understanding with Natural Language) model.

    VINE is a video understanding model that processes categorical (object class names),
    unary keywords (actions on one object), and binary keywords (relations between two objects),
    and returns probability distributions over all of them when passed a video.

    Args:
        model_name (str): The CLIP model name to use as backbone. Default: "openai/clip-vit-base-patch32"
        hidden_dim (int): Hidden dimension size. Default: 768
        use_hf_repo (bool): Selects which weight-location fields are kept. If True,
            `model_repo` / `model_file` are stored and `local_dir` / `local_filename`
            are set to None; if False, the reverse. Default: False
        model_repo (str, optional): Repository identifier, used only when `use_hf_repo` is True. Default: None
        model_file (str, optional): File name within `model_repo`, used only when `use_hf_repo` is True. Default: None
        local_dir (str, optional): Local directory, used only when `use_hf_repo` is False. Default: None
        local_filename (str, optional): File name within `local_dir`, used only when `use_hf_repo` is False. Default: None
        num_top_pairs (int): Number of top object pairs to consider. Default: 18
        segmentation_method (str): Segmentation method to use ("sam2" or "grounding_dino_sam2"). Default: "grounding_dino_sam2"
        box_threshold (float): Box threshold for Grounding DINO. Default: 0.35
        text_threshold (float): Text threshold for Grounding DINO. Default: 0.25
        target_fps (int): Target FPS for video processing. Default: 1
        alpha (float): Alpha value for object extraction. Default: 0.5
        white_alpha (float): White alpha value for background blending. Default: 0.8
        topk_cate (int): Top-k categories to return. Default: 3
        multi_class (bool): Whether to use multi-class classification. Default: False
        output_logit (bool): Whether to output logits instead of probabilities. Default: False
        max_video_length (int): Maximum number of frames to process. Default: 100
        bbox_min_dim (int): Minimum bounding box dimension. Default: 5
        visualize (bool): Whether to visualize results. Default: False
        visualization_dir (str, optional): Directory to save visualizations. Default: None
        return_flattened_segments (bool): Whether to return flattened segments. Default: False
        return_valid_pairs (bool): Whether to return valid object pairs. Default: False
        interested_object_pairs (List[Tuple[int, int]], optional): List of interested object pairs.
            Stored as an empty list when None. Default: None
        debug_visualizations (bool): Whether to save debug visualizations. Default: False
        device (str | int, optional): Target device, stored on `self._device`.
            An int is treated as a CUDA index and becomes "cuda:<index>" when
            CUDA is available, otherwise "cpu". A string is used as given.
            None (or any other falsy value) resolves to "cuda" when CUDA is
            available, otherwise "cpu". Default: None
        **kwargs: Forwarded unchanged to `PretrainedConfig.__init__`.
    """

    model_type = "vine"

    def __init__(
        self,
        model_name: str = "openai/clip-vit-base-patch32",
        hidden_dim: int = 768,

        use_hf_repo: bool = False,
        model_repo: Optional[str] = None,
        model_file: Optional[str] = None,
        local_dir: Optional[str] = None,
        local_filename: Optional[str] = None,

        num_top_pairs: int = 18,
        segmentation_method: str = "grounding_dino_sam2",
        box_threshold: float = 0.35,
        text_threshold: float = 0.25,
        target_fps: int = 1,
        alpha: float = 0.5,
        white_alpha: float = 0.8,
        topk_cate: int = 3,
        multi_class: bool = False,
        output_logit: bool = False,
        max_video_length: int = 100,
        bbox_min_dim: int = 5,
        visualize: bool = False,
        visualization_dir: Optional[str] = None,
        return_flattened_segments: bool = False,
        return_valid_pairs: bool = False,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: bool = False,
        device: Optional[str | int] = None,
        **kwargs
    ):
        self.model_name = model_name
        self.use_hf_repo = use_hf_repo
        # Exactly one weight source is retained: the repo pair or the local
        # pair. The unused pair is cleared to None regardless of what was passed.
        if use_hf_repo:
            self.model_repo = model_repo
            self.model_file = model_file
            self.local_dir = None
            self.local_filename = None
        else:
            self.model_repo = None
            self.model_file = None
            self.local_dir = local_dir
            self.local_filename = local_filename
        self.hidden_dim = hidden_dim
        self.num_top_pairs = num_top_pairs
        self.segmentation_method = segmentation_method
        self.box_threshold = box_threshold
        self.text_threshold = text_threshold
        self.target_fps = target_fps
        self.alpha = alpha
        self.white_alpha = white_alpha
        self.topk_cate = topk_cate
        self.multi_class = multi_class
        self.output_logit = output_logit
        self.max_video_length = max_video_length
        self.bbox_min_dim = bbox_min_dim
        self.visualize = visualize
        self.visualization_dir = visualization_dir
        self.return_flattened_segments = return_flattened_segments
        self.return_valid_pairs = return_valid_pairs
        # Avoid a shared mutable default: build a fresh list per instance.
        self.interested_object_pairs = interested_object_pairs or []
        self.debug_visualizations = debug_visualizations

        cuda_available = torch.cuda.is_available()
        # `isinstance` (not `is int`) so that real integers such as 0 or 1 are
        # recognised as CUDA indices. `bool` is a subclass of `int`, so it is
        # excluded to keep True/False on the fallback path.
        if isinstance(device, int) and not isinstance(device, bool):
            self._device = f"cuda:{device}" if cuda_available else "cpu"
        else:
            self._device = device or ("cuda" if cuda_available else "cpu")

        super().__init__(**kwargs)