Eavn committed on
Commit
0dc8f7a
·
verified ·
1 Parent(s): a9f3c19

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +62 -0
README.md ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import hf_hub_download
2
+ import torch
3
+ import clip
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import cv2
7
+ import torchvision.transforms as transforms
8
+
9
+
10
def generate_event_image(frames, threshold=10):
    """Accumulate a single event image from a stack of RGB frames.

    Consecutive frames are absolute-differenced, converted to grayscale,
    and binarised at ``threshold``; the resulting 0/255 binary maps are
    summed into one activity map.

    Args:
        frames: Sequence of H x W x 3 frames. Assumed uint8 RGB — the
            cv2.COLOR_RGB2GRAY conversion presumes RGB channel order
            (TODO confirm against the caller's frame source).
        threshold: Grayscale difference above which a pixel counts as an
            event.

    Returns:
        torch.Tensor of shape (H, W): the sum of the per-pair binary
        maps (int64 — torch promotes integer sums to int64).
    """
    frames = np.asarray(frames)

    event_images = []
    for prev, curr in zip(frames[:-1], frames[1:]):
        diff = cv2.absdiff(curr, prev)
        gray_diff = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)
        _, event_image = cv2.threshold(gray_diff, threshold, 255, cv2.THRESH_BINARY)
        event_images.append(event_image)

    # Fewer than two frames -> no differences; return an all-zero map of
    # the right shape instead of the original's degenerate 0-d scalar.
    if not event_images:
        return torch.zeros(frames.shape[1:3], dtype=torch.int64)

    # np.stack + from_numpy avoids the slow (and UserWarning-emitting)
    # torch.tensor(list-of-ndarrays) conversion path.
    return torch.from_numpy(np.stack(event_images)).sum(dim=0)
22
+
23
+
24
# --- Download the event-CLIP checkpoint from the Hub -----------------------
ckpt_path = hf_hub_download(
    repo_id="Eavn/event-clip",
    filename="vitb.pt",  # or vitl.pt for pretraining checkpoints
    repo_type="model"
)

model, preprocess = clip.load("ViT-B/32")
# model, preprocess = clip.load("ViT-L/14")

# NOTE(review): torch.load unpickles arbitrary objects — only load
# checkpoints from trusted sources. map_location="cpu" makes the load
# device-independent; weights move to the GPU with the model below.
state_dict = torch.load(ckpt_path, map_location="cpu")["checkpoint"]
# Keep only the momentum-encoder ("encoder_k") weights and strip the
# prefix so the keys line up with the plain CLIP model.
new_state_dict = {
    key.replace('encoder_k.', ''): value
    for key, value in state_dict.items()
    if 'encoder_k' in key
}
model.load_state_dict(new_state_dict)

transform = transforms.Compose([
    transforms.Resize((224, 224)),
])

stack_size = 16  # frames folded into one event image
threshold = 10   # per-pixel difference threshold for an "event"
clamp = 10       # cap on accumulated event counts (<= 0 disables)
text = 'Put the Text Here'
text = clip.tokenize([text]).cuda()

# Dummy RGB clip for the example; replace with real video frames.
images = (np.random.rand(32, 224, 224, 3) * 255).astype(np.uint8)
event = generate_event_image(
    images[:stack_size],
    threshold=threshold
)
if clamp > 0:
    event = torch.clamp(event, min=0, max=clamp)
# Normalise to [0, 1]; guard the all-zero map, which would otherwise
# divide by zero and produce a NaN image.
max_val = event.max()
event = event / max_val if max_val > 0 else event.float()
event = torch.stack([event, event, event])  # replicate to 3 channels
event = transform(event)
event = event.cuda().unsqueeze(0)           # add batch dimension

logits_per_event, _ = model(event, text)