Hayloo9838 committed on
Commit
07ff5a5
·
verified ·
1 Parent(s): 990ca41

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +66 -59
model.py CHANGED
@@ -1,21 +1,24 @@
1
- # model usage + heatmap of attention (what the model is focusing on)
2
-
3
  import cv2
4
  import torch
5
  import numpy as np
6
  from transformers import CLIPProcessor, CLIPVisionModel
7
  from PIL import Image
8
  from torch import nn
 
 
 
9
 
10
  MODEL_PATH = "pytorch_model.bin"
 
 
11
 
12
  class CLIPVisionClassifier(nn.Module):
13
  def __init__(self, num_labels):
14
  super().__init__()
15
  self.vision_model = CLIPVisionModel.from_pretrained('openai/clip-vit-large-patch14',
16
- attn_implementation="eager") # shows heat
17
  self.classifier = nn.Linear(self.vision_model.config.hidden_size, num_labels, bias=False)
18
- self.dropout = nn.Dropout(0.1) # this is not used dont worry :
19
 
20
  def forward(self, pixel_values, output_attentions=False):
21
  outputs = self.vision_model(pixel_values, output_attentions=output_attentions)
@@ -26,7 +29,7 @@ class CLIPVisionClassifier(nn.Module):
26
  return logits, outputs.attentions
27
  return logits
28
 
29
- def get_attention_map(attentions, image_size=(224, 224)):
30
  attention = attentions[-1]
31
  attention = attention.mean(dim=1)
32
  attention = attention[0, 0, 1:]
@@ -36,82 +39,72 @@ def get_attention_map(attentions, image_size=(224, 224)):
36
  attention_map = attention.reshape(num_patches, num_patches)
37
 
38
  attention_map = attention_map.cpu().numpy()
39
- attention_map = cv2.resize(attention_map, image_size, interpolation=cv2.INTER_LINEAR)
40
 
41
  attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min())
42
  return attention_map
43
 
44
- def apply_heatmap(image, attention_map, new_size=(640, 480)):
45
  heatmap = cv2.applyColorMap(np.uint8(255 * attention_map), cv2.COLORMAP_JET)
46
 
47
  if isinstance(image, Image.Image):
48
  image = np.array(image)
49
  image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- image_resized = cv2.resize(image, new_size)
 
 
 
52
 
53
- attention_map_resized = cv2.resize(attention_map, new_size, interpolation=cv2.INTER_LINEAR)
54
 
55
- attention_map_resized = (attention_map_resized - attention_map_resized.min()) / (attention_map_resized.max() - attention_map_resized.min())
 
56
 
57
- heatmap_resized = cv2.applyColorMap(np.uint8(255 * attention_map_resized), cv2.COLORMAP_JET)
 
 
 
58
 
59
- output = cv2.addWeighted(image_resized, 0.7, heatmap_resized, 0.3, 0)
 
60
 
61
- return output
62
-
63
- def webcam_card_detection():
64
- model, processor, reverse_mapping, device = load_model()
65
 
66
- cap = cv2.VideoCapture(0)
67
- print("Press 'q' to quit.")
68
-
69
- while True:
70
- ret, frame = cap.read()
71
- if not ret:
72
- print("Failed to capture image. Exiting...")
73
- break
74
-
75
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
76
- image = Image.fromarray(frame_rgb)
77
-
78
- inputs = processor(images=image, return_tensors="pt")
79
- pixel_values = inputs.pixel_values.to(device)
80
-
81
- with torch.no_grad():
82
- logits, attentions = model(pixel_values, output_attentions=True)
83
- probs = torch.nn.functional.softmax(logits, dim=-1)
84
- prediction = torch.argmax(probs).item()
85
-
86
- # Generate attention map
87
- attention_map = get_attention_map(attentions)
88
-
89
- visualization = apply_heatmap(frame, attention_map, new_size=(640, 480))
90
-
91
- card_name = reverse_mapping[prediction]
92
- confidence = probs[0][prediction].item()
93
-
94
- cv2.putText(visualization, f"{card_name} ({confidence:.2%})", (10, 50),
95
- cv2.FONT_HERSHEY_SIMPLEX, 1, (1, 255, 255), 2, cv2.LINE_AA)
96
-
97
- cv2.imshow("UNO Card Detection", visualization)
98
-
99
- if cv2.waitKey(1) & 0xFF == ord('q'):
100
- print("Exiting...")
101
- break
102
 
103
- cap.release()
104
- cv2.destroyAllWindows()
105
-
 
106
 
107
  def load_model():
108
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
109
- checkpoint = torch.load(MODEL_PATH, map_location=device)
 
 
 
 
110
  label_mapping = checkpoint['label_mapping']
111
  reverse_mapping = {v: k for k, v in label_mapping.items()}
112
-
113
  model = CLIPVisionClassifier(len(label_mapping))
114
- model.load_state_dict(checkpoint['model_state_dict'])
 
 
 
115
  model = model.to(device)
116
  model.eval()
117
 
@@ -120,4 +113,18 @@ def load_model():
120
  return model, processor, reverse_mapping, device
121
 
122
  if __name__ == "__main__":
123
- webcam_card_detection()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import cv2
2
  import torch
3
  import numpy as np
4
  from transformers import CLIPProcessor, CLIPVisionModel
5
  from PIL import Image
6
  from torch import nn
7
+ import requests
8
+ import matplotlib.pyplot as plt
9
+ from huggingface_hub import hf_hub_download
10
 
11
  MODEL_PATH = "pytorch_model.bin"
12
+ REPO_ID = "Hayloo9838/uno-recognizer"
13
+ MAPANDSTUFF = "mapandstuff.pth"
14
 
15
  class CLIPVisionClassifier(nn.Module):
16
  def __init__(self, num_labels):
17
  super().__init__()
18
  self.vision_model = CLIPVisionModel.from_pretrained('openai/clip-vit-large-patch14',
19
+ attn_implementation="eager")
20
  self.classifier = nn.Linear(self.vision_model.config.hidden_size, num_labels, bias=False)
21
+ self.dropout = nn.Dropout(0.1)
22
 
23
  def forward(self, pixel_values, output_attentions=False):
24
  outputs = self.vision_model(pixel_values, output_attentions=output_attentions)
 
29
  return logits, outputs.attentions
30
  return logits
31
 
32
+ def get_attention_map(attentions):
33
  attention = attentions[-1]
34
  attention = attention.mean(dim=1)
35
  attention = attention[0, 0, 1:]
 
39
  attention_map = attention.reshape(num_patches, num_patches)
40
 
41
  attention_map = attention_map.cpu().numpy()
 
42
 
43
  attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min())
44
  return attention_map
45
 
46
def apply_heatmap(image, attention_map, new_size=None):
    """Blend a JET-colormapped attention heatmap over an image.

    Args:
        image: a PIL.Image (converted to a BGR numpy array) or an
            already-BGR numpy array.
        attention_map: 2D float array of attention weights; it is resized
            to the image and re-normalized to [0, 1] before colormapping.
        new_size: optional (width, height); when given, the image is
            resized to this size before blending.

    Returns:
        BGR numpy array: 70% image + 30% heatmap overlay.
    """
    # Fix: the original computed an initial cv2.applyColorMap(attention_map)
    # whose result was never used (both branches rebuilt the heatmap from
    # the resized map); the dead computation is removed and the two
    # near-identical branches are unified.
    if isinstance(image, Image.Image):
        image = np.array(image)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    if new_size is not None:
        image = cv2.resize(image, new_size)

    # cv2.resize expects (width, height); numpy shape is (height, width).
    resized_map = cv2.resize(attention_map, image.shape[:2][::-1],
                             interpolation=cv2.INTER_LINEAR)
    # Re-normalize after interpolation so the colormap spans the full range.
    # NOTE(review): a constant map would divide by zero — same as original.
    resized_map = (resized_map - resized_map.min()) / (
        resized_map.max() - resized_map.min())
    heatmap = cv2.applyColorMap(np.uint8(255 * resized_map), cv2.COLORMAP_JET)

    return cv2.addWeighted(image, 0.7, heatmap, 0.3, 0)
67
+
68
def process_image_classification(image_url):
    """Classify an UNO card image fetched from a URL and visualize attention.

    Downloads the image, runs the CLIP-based classifier, and overlays the
    last-layer attention map on the input.

    Args:
        image_url: HTTP(S) URL of the card image.

    Returns:
        Tuple of (visualization_rgb, card_name, confidence) where
        visualization_rgb is an RGB numpy array suitable for matplotlib.
    """
    model, processor, reverse_mapping, device = load_model()

    response = requests.get(image_url, stream=True)
    image = Image.open(response.raw).convert('RGB')

    inputs = processor(images=image, return_tensors="pt")
    pixel_values = inputs.pixel_values.to(device)

    with torch.no_grad():
        logits, attentions = model(pixel_values, output_attentions=True)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        prediction = torch.argmax(probs).item()

    # Build the attention overlay for the predicted card.
    visualization = apply_heatmap(image, get_attention_map(attentions))

    card_name = reverse_mapping[prediction]
    confidence = probs[0][prediction].item()

    # apply_heatmap works in OpenCV's BGR order; flip back for matplotlib.
    visualization_rgb = cv2.cvtColor(visualization, cv2.COLOR_BGR2RGB)

    return visualization_rgb, card_name, confidence
93
 
94
  def load_model():
95
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
96
+
97
+ # Download model weights and label mapping from Hugging Face Hub
98
+ model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_PATH)
99
+ #mapandstuff_path = hf_hub_download(repo_id=REPO_ID, filename=MAPANDSTUFF)
100
+ checkpoint = torch.load(model_path, map_location=device)
101
  label_mapping = checkpoint['label_mapping']
102
  reverse_mapping = {v: k for k, v in label_mapping.items()}
 
103
  model = CLIPVisionClassifier(len(label_mapping))
104
+
105
+ model_state_dict = checkpoint["model_state_dict"]
106
+ model.load_state_dict(model_state_dict)
107
+
108
  model = model.to(device)
109
  model.eval()
110
 
 
113
  return model, processor, reverse_mapping, device
114
 
115
if __name__ == "__main__":
    # Demo: classify a sample UNO "reverse" card image from the web.
    image_url = "https://www.shutterstock.com/image-vector/hand-hold-reverse-card-symbol-600w-2360073097.jpg"
    visualization, card_name, confidence = process_image_classification(image_url)

    plt.figure(figsize=(10, 5))

    # Left panel: attention heatmap blended over the input image.
    plt.subplot(1, 2, 1)
    plt.imshow(visualization)
    plt.title(f"Heatmap on Image")
    plt.axis('off')

    # Right panel: the prediction and its confidence as text.
    plt.subplot(1, 2, 2)
    plt.text(0.5, 0.5, f"Predicted Card: {card_name}\nConfidence: {confidence:.2%}",
             fontsize=12, ha='center', va='center')
    plt.axis('off')

    plt.show()