bpiyush commited on
Commit
2e75ade
·
verified ·
1 Parent(s): ed675d3

Upload demo_usage.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. demo_usage.py +73 -0
demo_usage.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from termcolor import colored
3
+ from modeling_tara import TARA, read_frames_decord
4
+
5
+
6
+ def main():
7
+ print(colored("="*60, 'yellow'))
8
+ print(colored("TARA Model Demo", 'yellow', attrs=['bold']))
9
+ print(colored("="*60, 'yellow'))
10
+
11
+ # Load model from current directory
12
+ print(colored("\n[1/3] Loading model...", 'cyan'))
13
+ model = TARA.from_pretrained(
14
+ ".", # Load from current directory
15
+ device_map='auto',
16
+ torch_dtype=torch.bfloat16,
17
+ )
18
+
19
+ n_params = sum(p.numel() for p in model.model.parameters())
20
+ print(colored(f"✓ Model loaded successfully!", 'green'))
21
+ print(f"Number of parameters: {round(n_params/1e9, 3)}B")
22
+
23
+ # Encode a sample video
24
+ print(colored("\n[2/3] Testing video encoding...", 'cyan'))
25
+ video_path = "./assets/folding_paper.mp4"
26
+
27
+ try:
28
+ video_tensor = read_frames_decord(video_path, num_frames=16)
29
+ video_tensor = video_tensor.unsqueeze(0)
30
+ video_tensor = video_tensor.to(model.model.device)
31
+
32
+ with torch.no_grad():
33
+ video_emb = model.encode_vision(video_tensor).cpu().squeeze(0).float()
34
+
35
+ print(colored("✓ Video encoded successfully!", 'green'))
36
+ print(f"Video shape: {video_tensor.shape}") # torch.Size([1, 16, 3, 240, 426])
37
+ print(f"Video embedding shape: {video_emb.shape}") # torch.Size([4096])
38
+ except FileNotFoundError:
39
+ print(colored(f"⚠ Video file not found: {video_path}", 'red'))
40
+ print(colored(" Please add a video file or update the path in demo_usage.py", 'yellow'))
41
+ video_emb = None
42
+
43
+ # Encode sample texts
44
+ print(colored("\n[3/3] Testing text encoding...", 'cyan'))
45
+ text = ['someone is folding a paper', 'cutting a paper', 'someone is folding a paper']
46
+ # NOTE: It can also take a single string
47
+
48
+ with torch.no_grad():
49
+ text_emb = model.encode_text(text).cpu().float()
50
+
51
+ print(colored("✓ Text encoded successfully!", 'green'))
52
+ print(f"Text: {text}")
53
+ print(f"Text embedding shape: {text_emb.shape}") # torch.Size([3, 4096])
54
+
55
+ # Compute similarities if video was encoded
56
+ if video_emb is not None:
57
+ print(colored("\n[Bonus] Computing video-text similarities...", 'cyan'))
58
+ similarities = torch.cosine_similarity(
59
+ video_emb.unsqueeze(0).unsqueeze(0), # [1, 1, 4096]
60
+ text_emb.unsqueeze(0), # [1, 3, 4096]
61
+ dim=-1
62
+ )
63
+ print(colored("✓ Similarities computed!", 'green'))
64
+ for i, txt in enumerate(text):
65
+ print(f" '{txt}': {similarities[0, i].item():.4f}")
66
+
67
+ print(colored("\n" + "="*60, 'yellow'))
68
+ print(colored("Demo completed successfully! 🎉", 'green', attrs=['bold']))
69
+ print(colored("="*60, 'yellow'))
70
+
71
+
72
# Script entry point: run the demo only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()