fokan committed
Commit 141b176 · verified · 1 Parent(s): 1317b31

Upload model.py with huggingface_hub

Files changed (1)
  1. model.py +69 -0
model.py ADDED
@@ -0,0 +1,69 @@
+ """
+ Custom Student Model for Knowledge Distillation
+ """
+ import torch
+ import torch.nn as nn
+ from transformers import PreTrainedModel, PretrainedConfig
+ from typing import Dict, Any, List, Optional
+
+ class StudentModelConfig(PretrainedConfig):
+     model_type = "distilled_student"
+
+     def __init__(
+         self,
+         hidden_size=768,
+         num_layers=12,
+         num_attention_heads=12,
+         intermediate_size=3072,
+         vocab_size=30522,
+         max_position_embeddings=512,
+         modalities=None,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.hidden_size = hidden_size
+         self.num_layers = num_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         # Avoid a mutable default argument; fall back to text-only.
+         self.modalities = modalities if modalities is not None else ["text"]
+
+ class StudentModel(PreTrainedModel):
+     config_class = StudentModelConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.num_layers = config.num_layers
+         self.modalities = config.modalities
+
+         # Build model layers based on config
+         self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = nn.ModuleList([
+             nn.TransformerEncoderLayer(
+                 d_model=config.hidden_size,
+                 nhead=config.num_attention_heads,
+                 dim_feedforward=config.intermediate_size,
+                 batch_first=True
+             ) for _ in range(config.num_layers)
+         ])
+         self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
+
+     def forward(self, input_ids=None, attention_mask=None, **kwargs):
+         if input_ids is not None:
+             embeddings = self.embeddings(input_ids)
+         else:
+             # Handle other modalities via pre-computed embeddings
+             embeddings = kwargs.get('inputs_embeds')
+         if embeddings is None:
+             raise ValueError("Either input_ids or inputs_embeds must be provided")
+
+         # nn.TransformerEncoderLayer expects True at *padded* positions,
+         # i.e. the inverse of the Hugging Face attention_mask convention.
+         padding_mask = (attention_mask == 0) if attention_mask is not None else None
+         for layer in self.layers:
+             embeddings = layer(embeddings, src_key_padding_mask=padding_mask)
+
+         # Mean-pool over the sequence dimension (padding included), then project.
+         pooled = self.pooler(embeddings.mean(dim=1))
+
+         return {
+             'last_hidden_state': embeddings,
+             'pooler_output': pooled
+         }
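
A minimal usage sketch, assuming the uploaded model.py is importable from the working directory; the student sizes, the dummy batch, and the "student-checkpoint" path are illustrative only and not part of the commit:

import torch
from model import StudentModel, StudentModelConfig

# Build a small student for a quick smoke test (sizes are arbitrary).
config = StudentModelConfig(hidden_size=256, num_layers=4,
                            num_attention_heads=4, intermediate_size=1024)
student = StudentModel(config)

# Dummy batch: 2 sequences of 16 token ids, last 4 positions treated as padding.
input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones(2, 16, dtype=torch.long)
attention_mask[:, 12:] = 0

outputs = student(input_ids=input_ids, attention_mask=attention_mask)
print(outputs['last_hidden_state'].shape)  # torch.Size([2, 16, 256])
print(outputs['pooler_output'].shape)      # torch.Size([2, 256])

# Because StudentModel subclasses PreTrainedModel, the usual Hugging Face
# save/load round trip is available out of the box.
student.save_pretrained("student-checkpoint")
reloaded = StudentModel.from_pretrained("student-checkpoint")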