wendys-llc commited on
Commit
45be8eb
·
verified ·
1 Parent(s): bbcd059

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +55 -4
model.py CHANGED
@@ -1,6 +1,10 @@
1
- from transformers import PreTrainedModel, PretrainedConfig
 
2
  import torch.nn as nn
3
  from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
 
 
 
4
 
5
  class CheckboxConfig(PretrainedConfig):
6
  model_type = "checkbox-classifier"
@@ -9,6 +13,44 @@ class CheckboxConfig(PretrainedConfig):
9
  super().__init__(num_labels=num_labels, **kwargs)
10
  self.dropout_rate = dropout_rate
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  class CheckboxClassifier(PreTrainedModel):
13
  config_class = CheckboxConfig
14
 
@@ -16,7 +58,7 @@ class CheckboxClassifier(PreTrainedModel):
16
  super().__init__(config)
17
  self.num_labels = config.num_labels
18
 
19
- self.backbone = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
20
  num_features = self.backbone.classifier[1].in_features
21
 
22
  self.backbone.classifier = nn.Sequential(
@@ -32,6 +74,15 @@ class CheckboxClassifier(PreTrainedModel):
32
  nn.Linear(256, config.num_labels)
33
  )
34
 
35
- def forward(self, pixel_values):
36
  outputs = self.backbone(pixel_values)
37
- return {"logits": outputs}
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel, PretrainedConfig, ImageProcessingMixin
2
+ import torch
3
  import torch.nn as nn
4
  from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
5
+ from torchvision import transforms
6
+ from PIL import Image
7
+ import numpy as np
8
 
9
  class CheckboxConfig(PretrainedConfig):
10
  model_type = "checkbox-classifier"
 
13
  super().__init__(num_labels=num_labels, **kwargs)
14
  self.dropout_rate = dropout_rate
15
 
16
+ class CheckboxImageProcessor(ImageProcessingMixin):
17
+ """Simple image processor for checkbox classifier"""
18
+
19
+ def __init__(self, **kwargs):
20
+ super().__init__(**kwargs)
21
+ self.size = {"height": 128, "width": 128}
22
+ self.image_mean = [0.485, 0.456, 0.406]
23
+ self.image_std = [0.229, 0.224, 0.225]
24
+
25
+ self.transform = transforms.Compose([
26
+ transforms.Resize((self.size["height"], self.size["width"])),
27
+ transforms.ToTensor(),
28
+ transforms.Normalize(mean=self.image_mean, std=self.image_std)
29
+ ])
30
+
31
+ def preprocess(self, images, **kwargs):
32
+ """Preprocess images for model input"""
33
+ if not isinstance(images, list):
34
+ images = [images]
35
+
36
+ processed = []
37
+ for image in images:
38
+ if isinstance(image, str):
39
+ image = Image.open(image).convert('RGB')
40
+ elif isinstance(image, np.ndarray):
41
+ image = Image.fromarray(image).convert('RGB')
42
+ elif not isinstance(image, Image.Image):
43
+ raise ValueError(f"Unsupported image type: {type(image)}")
44
+
45
+ processed.append(self.transform(image))
46
+
47
+ # Stack into batch
48
+ pixel_values = torch.stack(processed)
49
+ return {"pixel_values": pixel_values}
50
+
51
+ def __call__(self, images, **kwargs):
52
+ return self.preprocess(images, **kwargs)
53
+
54
  class CheckboxClassifier(PreTrainedModel):
55
  config_class = CheckboxConfig
56
 
 
58
  super().__init__(config)
59
  self.num_labels = config.num_labels
60
 
61
+ self.backbone = efficientnet_v2_s(weights=None) # Don't load pretrained weights here
62
  num_features = self.backbone.classifier[1].in_features
63
 
64
  self.backbone.classifier = nn.Sequential(
 
74
  nn.Linear(256, config.num_labels)
75
  )
76
 
77
+ def forward(self, pixel_values, labels=None):
78
  outputs = self.backbone(pixel_values)
79
+
80
+ loss = None
81
+ if labels is not None:
82
+ loss_fct = nn.CrossEntropyLoss()
83
+ loss = loss_fct(outputs, labels)
84
+
85
+ return {
86
+ "loss": loss,
87
+ "logits": outputs,
88
+ }