SajayR committed · Commit a3ce7b0 · verified · Parent(s): d00ed97

Update README.md

Files changed (1): README.md (+179 -6)
README.md CHANGED
---
license: mit
---

# Triad: Dense Cross-Modal Feature Learning
![image/png](https://cdn-uploads.huggingface.co/production/uploads/64792e9d50ff700163188784/2o6JBAgVerp5sUVM7WChK.png)

I built Triad to explore dense feature correspondences between video, audio, and text modalities, focusing on learning fine-grained, localized relationships rather than just global alignment. The goal was to create a model that can ground features between specific image regions, audio segments, and text spans simultaneously.

This is a very early research checkpoint for dense multi-modal learning, with lots of room for improvement and experimentation. The current model was trained on a subset of AudioSet (~400k videos, about 20% of the full dataset) and CC3M (~2M image-text pairs) for just one epoch, so while it shows promising behavior, it's definitely not state-of-the-art yet.

## What Makes This Interesting?
Unlike models that learn global alignment between modalities (think CLIP, ImageBind), Triad learns to map specific parts of each modality to each other. This means it can:
- Locate which parts of an image correspond to particular words or sounds
- Ground audio segments to relevant visual regions
- Connect text descriptions to precise areas in images
- (Potentially) Learn transitive audio-text relationships through the shared visual space

## What's Next?
I've got lots of ideas for making this better: longer training, experimenting with the architecture, investigating some interesting behaviors I've noticed, and tackling the big open problem of text and audio features that have no counterpart in the visual features.

I'm actively looking to push this research further, and I'm very interested in tackling more multimodal learning problems. Feel free to reach out if you're working in this space!

## Inference

The model can process image, audio, and text inputs, either individually or together.

## Installation & Loading

```python
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
import torch
import json
import sys
from pathlib import Path

def load_model(path="SajayR/Triad", device="cpu"):
    # Fetch weights, config, and the model definition from the Hub
    model_path = hf_hub_download(repo_id=path, filename="model.safetensors")
    model_config = hf_hub_download(repo_id=path, filename="config.json")
    model_arch = hf_hub_download(repo_id=path, filename="hf_model.py")

    # Make the downloaded hf_model.py importable
    sys.path.append(str(Path(model_arch).parent))
    from hf_model import Triad

    with open(model_config) as f:
        model = Triad(**json.load(f))
    weights = load_file(model_path)
    model.load_state_dict(weights)
    return model.to(device)

# Initialize model
model = load_model()  # Use load_model(device="cuda") for GPU
```

## Single Modality Examples

### Image Input

You can provide images as file paths or tensors:

```python
# From file path
output = model(image="path/to/image.jpg")
output['visual_feats'].shape  # torch.Size([1, 256, 512])

# From tensor (already pre-processed)
from torchvision import transforms
from PIL import Image

# Load and preprocess image
image = Image.open("path/to/image.jpg").convert('RGB')
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image_tensor = transform(image)  # Shape: [3, 224, 224]

# Pass to model
output = model(image=image_tensor)
output['visual_feats'].shape  # torch.Size([1, 256, 512])
```
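Since `visual_feats` carries 256 patch tokens per image, it can be viewed as a square spatial grid, which is convenient for heatmap-style visualization. A minimal sketch with a stand-in tensor, assuming the 256 patches form a 16x16 grid (my assumption for a 224x224 input, not confirmed by the repo):

```python
import torch

# Stand-in for output['visual_feats']; replace with the real model output
visual_feats = torch.randn(1, 256, 512)

# View the 256 patch tokens as an assumed 16x16 spatial grid
grid = visual_feats.reshape(1, 16, 16, 512)
```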

### Audio Input

```python
# Audio only - returns audio features (B, N_segments, D)
# The model is currently trained on 1-second audio segments; longer sequences may perform worse.
audio = torch.randn(1, 16331)  # Raw audio waveform
output = model(audio=audio)
output['audio_feats'].shape  # torch.Size([1, 50, 512])
```
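To feed real audio rather than random noise, the 1-second note above suggests preparing fixed-length chunks. A sketch using a synthetic sine wave; the 16 kHz sample rate is my assumption, since the repo does not state the expected rate:

```python
import torch

# Synthetic 1-second, 440 Hz tone as a stand-in for a real clip.
# The 16 kHz sample rate is an assumption; check the training pipeline.
sr = 16000
t = torch.arange(sr) / sr
audio = torch.sin(2 * torch.pi * 440 * t).unsqueeze(0)  # Shape: [1, 16000]

# output = model(audio=audio)  # then query output['audio_feats']
```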

### Text Input

```python
# Text only - returns text features (B, N_tokens, D)
text_list = ["a man riding a bicycle"]
output = model(text_list=text_list)
output['text_feats'].shape  # torch.Size([1, 5, 512])
```
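If you want cosine-style comparisons between token features, you can L2-normalize the embedding dimension yourself; whether Triad already normalizes internally isn't stated, so this step may be redundant. A sketch with a stand-in tensor:

```python
import torch
import torch.nn.functional as F

# Stand-in for output['text_feats'] with shape (B, N_tokens, D)
text_feats = torch.randn(1, 5, 512)

# L2-normalize along the embedding dim so dot products become cosine similarities
text_feats = F.normalize(text_feats, dim=-1)
```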

## Batch Processing

The model now supports batch processing for image inputs:

### Batch of Image Paths

```python
# Process a batch of image paths
image_paths = ["path/to/image1.jpg", "path/to/image2.jpg", "path/to/image3.jpg"]
output = model(image=image_paths)
output['visual_feats'].shape  # torch.Size([3, 256, 512])
```

### Batch of Image Tensors

```python
# Process a batch of image tensors
import torch
from torchvision import transforms
from PIL import Image

# Create a transform
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load and preprocess images
images = []
for path in ["image1.jpg", "image2.jpg", "image3.jpg"]:
    img = Image.open(path).convert('RGB')
    images.append(transform(img))

# Stack into a batch
batch = torch.stack(images)  # Shape: [3, 3, 224, 224]

# Process the batch
output = model(image=batch)
output['visual_feats'].shape  # torch.Size([3, 256, 512])
```
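For more images than fit comfortably in one forward pass, a simple chunking loop over the path list works; the batch size of 8 here is arbitrary:

```python
# Split a long list of image paths into fixed-size batches (batch size 8 is arbitrary)
paths = [f"img_{i:04d}.jpg" for i in range(20)]
batch_size = 8
batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)]
# Each chunk can then be passed as model(image=batch)
```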

## Multi-Modal Examples

### Image and Audio Together

```python
# Process image and audio together
output = model(
    audio=audio,
    image="path/to/image.jpg"
)

print(output.keys())  # dict_keys(['visual_feats', 'audio_feats', 'vis_audio_sim_matrix'])

# Output shapes:
# - audio_feats: [1, 50, 512]           # (batch, audio_segments, features)
# - visual_feats: [1, 256, 512]         # (batch, image_patches, features)
# - vis_audio_sim_matrix: [1, 50, 256]  # (batch, audio_segments, image_patches)
```

The similarity matrix shows the correspondence between each audio segment and image patch.
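One natural use of `vis_audio_sim_matrix` is turning a single audio segment's row into a spatial heatmap over the image. A sketch with stand-in tensors, assuming the 256 patches form a 16x16 grid (my assumption, not stated in the repo):

```python
import torch

# Stand-in for output['vis_audio_sim_matrix']: (batch, audio_segments, image_patches)
sim = torch.randn(1, 50, 256)

# Heatmap for the first audio segment over the assumed 16x16 patch grid
heatmap = sim[0, 0].reshape(16, 16)

# Best-matching image patch index for every audio segment
best_patch = sim.argmax(dim=-1)  # Shape: [1, 50]
```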

## Output Key Reference

Depending on which modalities you provide, the model returns different outputs:

- `visual_feats`: (B, 256, 512) # When you pass an image
- `audio_feats`: (B, 50, 512) # When you pass audio
- `text_feats`: (B, N_tokens, 512) # When you pass text
- `vis_text_sim_matrix`: (B, N_tokens, 256) # When you pass both image and text
- `vis_audio_sim_matrix`: (B, 50, 256) # When you pass both image and audio
- `text_audio_sim_matrix`: (B, N_tokens, 50) # When you pass both text and audio

Where:
- B = batch size
- 256 = number of image patches
- 50 = number of audio segments
- N_tokens = variable number of text tokens
- 512 = embedding dimension
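The same shapes let you localize each text token on the image via `vis_text_sim_matrix`. A sketch with stand-in tensors, again assuming the 256 patches form a 16x16 grid (my assumption):

```python
import torch

# Stand-in for output['vis_text_sim_matrix']: (B, N_tokens, image_patches)
sim = torch.randn(1, 5, 256)

# Best-matching patch per token, then its (row, col) on the assumed 16x16 grid
best_patch = sim.argmax(dim=-1)  # Shape: [1, 5]
rows, cols = best_patch // 16, best_patch % 16
```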