HemanM committed on
Commit
cdd7a15
·
verified ·
1 Parent(s): 896bcee

Update evo_model.py

Browse files
Files changed (1) hide show
  1. evo_model.py +20 -13
evo_model.py CHANGED
@@ -1,8 +1,9 @@
 
 
1
  import torch
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel, PretrainedConfig
4
 
5
-
6
  class EvoTransformerConfig(PretrainedConfig):
7
  def __init__(
8
  self,
@@ -30,7 +31,7 @@ class EvoTransformerForClassification(PreTrainedModel):
30
  super().__init__(config)
31
  self.config = config
32
 
33
- # Expose architecture traits for dashboard or mutation
34
  self.num_layers = config.num_layers
35
  self.num_heads = config.num_heads
36
  self.ffn_dim = config.ffn_dim
@@ -42,7 +43,8 @@ class EvoTransformerForClassification(PreTrainedModel):
42
  nn.TransformerEncoderLayer(
43
  d_model=config.hidden_size,
44
  nhead=config.num_heads,
45
- dim_feedforward=config.ffn_dim
 
46
  )
47
  for _ in range(config.num_layers)
48
  ])
@@ -56,32 +58,37 @@ class EvoTransformerForClassification(PreTrainedModel):
56
  self.init_weights()
57
 
58
  def forward(self, input_ids, attention_mask=None, labels=None):
59
- x = self.embedding(input_ids) # [batch, seq_len, hidden_size]
60
- x = x.transpose(0, 1) # Transformer expects [seq_len, batch, hidden_size]
 
 
 
61
 
62
  for layer in self.layers:
63
- x = layer(x, src_key_padding_mask=(attention_mask == 0) if attention_mask is not None else None)
64
 
65
- x = x.mean(dim=0) # mean pooling
66
  logits = self.classifier(x)
67
 
68
  if labels is not None:
69
  loss = nn.functional.cross_entropy(logits, labels)
70
  return loss, logits
 
71
  return logits
72
 
73
  def save_pretrained(self, save_directory):
74
- import os, json
75
  os.makedirs(save_directory, exist_ok=True)
76
- torch.save(self.state_dict(), f"{save_directory}/pytorch_model.bin")
77
- with open(f"{save_directory}/config.json", "w") as f:
78
  f.write(self.config.to_json_string())
79
 
80
  @classmethod
81
  def from_pretrained(cls, load_directory):
82
- config_path = f"{load_directory}/config.json"
83
- model_path = f"{load_directory}/pytorch_model.bin"
84
  config = EvoTransformerConfig.from_json_file(config_path)
85
  model = cls(config)
86
- model.load_state_dict(torch.load(model_path, map_location="cpu"))
 
 
87
  return model
 
1
+ import os
2
+ import json
3
  import torch
4
  import torch.nn as nn
5
  from transformers import PreTrainedModel, PretrainedConfig
6
 
 
7
  class EvoTransformerConfig(PretrainedConfig):
8
  def __init__(
9
  self,
 
31
  super().__init__(config)
32
  self.config = config
33
 
34
+ # === Architecture traits for UI, mutation, etc.
35
  self.num_layers = config.num_layers
36
  self.num_heads = config.num_heads
37
  self.ffn_dim = config.ffn_dim
 
43
  nn.TransformerEncoderLayer(
44
  d_model=config.hidden_size,
45
  nhead=config.num_heads,
46
+ dim_feedforward=config.ffn_dim,
47
+ batch_first=False # Required for transpose trick
48
  )
49
  for _ in range(config.num_layers)
50
  ])
 
58
  self.init_weights()
59
 
60
  def forward(self, input_ids, attention_mask=None, labels=None):
61
+ # Embedding and prep for transformer
62
+ x = self.embedding(input_ids) # [batch, seq_len, hidden]
63
+ x = x.transpose(0, 1) # [seq_len, batch, hidden]
64
+
65
+ key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
66
 
67
  for layer in self.layers:
68
+ x = layer(x, src_key_padding_mask=key_padding_mask)
69
 
70
+ x = x.mean(dim=0) # [batch, hidden] — mean pooling
71
  logits = self.classifier(x)
72
 
73
  if labels is not None:
74
  loss = nn.functional.cross_entropy(logits, labels)
75
  return loss, logits
76
+
77
  return logits
78
 
79
  def save_pretrained(self, save_directory):
 
80
  os.makedirs(save_directory, exist_ok=True)
81
+ torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
82
+ with open(os.path.join(save_directory, "config.json"), "w") as f:
83
  f.write(self.config.to_json_string())
84
 
85
  @classmethod
86
  def from_pretrained(cls, load_directory):
87
+ config_path = os.path.join(load_directory, "config.json")
88
+ model_path = os.path.join(load_directory, "pytorch_model.bin")
89
  config = EvoTransformerConfig.from_json_file(config_path)
90
  model = cls(config)
91
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
92
+ model.load_state_dict(torch.load(model_path, map_location=device))
93
+ model.to(device)
94
  return model