Hayloo9838 committed on
Commit
fd8732a
·
verified ·
1 Parent(s): beb486c

Create model.py

Browse files

Standalone usage.

Files changed (1) hide show
  1. model.py +121 -0
model.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ from transformers import CLIPProcessor, CLIPVisionModel
5
+ from PIL import Image
6
+ from torch import nn
7
+
8
+ MODEL_PATH = "clip_large.pth"
9
+
10
class CLIPVisionClassifier(nn.Module):
    """CLIP ViT-L/14 vision backbone topped with a single linear head.

    Submodule creation order (vision_model, classifier, dropout) matches the
    training checkpoint so state_dict keys line up.
    """

    def __init__(self, num_labels):
        super().__init__()
        # Eager attention is required so attention weights can be returned
        # for the heat-map overlay (fused kernels do not expose them).
        self.vision_model = CLIPVisionModel.from_pretrained(
            'openai/clip-vit-large-patch14',
            attn_implementation="eager")
        self.classifier = nn.Linear(
            self.vision_model.config.hidden_size, num_labels, bias=False)
        # NOTE(review): dropout is defined but not applied in forward();
        # kept so checkpoint loading stays compatible.
        self.dropout = nn.Dropout(0.1)

    def forward(self, pixel_values, output_attentions=False):
        """Return class logits; with output_attentions=True, also the
        per-layer attention tensors."""
        vision_out = self.vision_model(
            pixel_values, output_attentions=output_attentions)
        logits = self.classifier(vision_out.pooler_output)
        if not output_attentions:
            return logits
        return logits, vision_out.attentions
26
+
27
def get_attention_map(attentions, image_size=(224, 224)):
    """Build a [0, 1]-normalized CLS-attention heat map from ViT attentions.

    Args:
        attentions: Tuple of per-layer attention tensors, each shaped
            (batch, heads, tokens, tokens), as produced with
            output_attentions=True.
        image_size: (width, height) of the returned map.

    Returns:
        float32 numpy array resized to image_size, scaled to [0, 1].
    """
    # Last layer, averaged over heads: (batch, tokens, tokens).
    attention = attentions[-1].mean(dim=1)
    # Attention paid by the CLS token (index 0) to every patch token
    # (indices 1:) of the first image in the batch.
    attention = attention[0, 0, 1:]

    # ViT patch tokens form a square grid; recover its side length.
    num_patches = int(np.sqrt(attention.shape[0]))
    attention_map = attention.reshape(num_patches, num_patches).cpu().numpy()

    attention_map = cv2.resize(attention_map, image_size,
                               interpolation=cv2.INTER_LINEAR)

    # Min-max normalize to [0, 1]. The epsilon guards against a constant
    # attention map, which would otherwise divide by zero and yield NaNs.
    span = attention_map.max() - attention_map.min()
    attention_map = (attention_map - attention_map.min()) / (span + 1e-8)
    return attention_map
41
+
42
def apply_heatmap(image, attention_map, new_size=(640, 480)):
    """Overlay a JET-colored attention heat map on an image.

    Args:
        image: BGR numpy array, or a PIL RGB image (converted to BGR here).
        attention_map: 2-D float array; re-normalized after resizing.
        new_size: (width, height) of the output.

    Returns:
        BGR numpy array of size new_size, blended 70% image / 30% heat map.
    """
    # Accept PIL input by converting to the BGR layout OpenCV expects.
    if isinstance(image, Image.Image):
        image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    image_resized = cv2.resize(image, new_size)

    attention_map_resized = cv2.resize(attention_map, new_size,
                                       interpolation=cv2.INTER_LINEAR)

    # Re-normalize after resizing; epsilon guards a constant map (previously
    # this could divide by zero). The original also computed an unused
    # colormap of the unresized map here — that dead code is removed.
    span = attention_map_resized.max() - attention_map_resized.min()
    attention_map_resized = (
        (attention_map_resized - attention_map_resized.min()) / (span + 1e-8))

    heatmap_resized = cv2.applyColorMap(np.uint8(255 * attention_map_resized),
                                        cv2.COLORMAP_JET)

    return cv2.addWeighted(image_resized, 0.7, heatmap_resized, 0.3, 0)
60
+
61
def webcam_card_detection():
    """Run live UNO-card classification on webcam frames until 'q' is pressed.

    Each frame is classified, the last-layer CLS attention is rendered as a
    heat-map overlay, and the predicted card name plus softmax confidence is
    drawn on the frame.
    """
    model, processor, reverse_mapping, device = load_model()

    cap = cv2.VideoCapture(0)
    # Fail fast instead of looping on read() errors when no camera exists.
    if not cap.isOpened():
        print("Failed to open webcam. Exiting...")
        return
    print("Press 'q' to quit.")

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("Failed to capture image. Exiting...")
                break

            # OpenCV captures BGR; the CLIP processor expects RGB.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(frame_rgb)

            inputs = processor(images=image, return_tensors="pt")
            pixel_values = inputs.pixel_values.to(device)

            with torch.no_grad():
                logits, attentions = model(pixel_values,
                                           output_attentions=True)
                probs = torch.nn.functional.softmax(logits, dim=-1)
                prediction = torch.argmax(probs).item()

            # Generate attention map and blend it over the raw BGR frame.
            attention_map = get_attention_map(attentions)
            visualization = apply_heatmap(frame, attention_map,
                                          new_size=(640, 480))

            card_name = reverse_mapping[prediction]
            confidence = probs[0][prediction].item()

            # NOTE(review): blue channel of 1 in (1, 255, 255) looks like a
            # typo for 255 — kept as-is to preserve behavior; confirm intent.
            cv2.putText(visualization, f"{card_name} ({confidence:.2%})",
                        (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1,
                        (1, 255, 255), 2, cv2.LINE_AA)

            cv2.imshow("UNO Card Detection", visualization)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                print("Exiting...")
                break
    finally:
        # Always release the camera and close windows, even if an exception
        # escapes the loop (previously the device was leaked on error).
        cap.release()
        cv2.destroyAllWindows()
103
+
104
+
105
def load_model():
    """Load the fine-tuned classifier checkpoint, processor, and label map.

    Returns:
        Tuple (model, processor, reverse_mapping, device) where
        reverse_mapping maps class index -> card name and model is in
        eval mode on the selected device.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    checkpoint = torch.load(MODEL_PATH, map_location=device)
    label_mapping = checkpoint['label_mapping']
    # The checkpoint stores name -> index; invert for prediction lookup.
    reverse_mapping = {index: name for name, index in label_mapping.items()}

    model = CLIPVisionClassifier(len(label_mapping))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    processor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14')

    return model, processor, reverse_mapping, device
119
+
120
# Script entry point: start the live webcam detection loop.
if __name__ == "__main__":
    webcam_card_detection()