sawyerhpowell committed on
Commit
a3d0fa4
·
verified ·
1 Parent(s): 2441092

Upload models.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models.py +63 -256
models.py CHANGED
@@ -1,258 +1,65 @@
1
- from random import randint
2
- from string import printable
3
- import numpy as np
4
  import torch
5
- from rapidfuzz.distance.Levenshtein import distance as ldistance
6
- from torch.optim import AdamW
7
- from models import EditDistanceModel
8
 
9
def pad_with_null(string: str, target_length: int):
    """Pad *string* with NUL ("\\0") characters out to *target_length*,
    truncating from the right if it is already longer."""
    filler = "\0" * max(0, target_length - len(string))
    return (string + filler)[:target_length]
13
-
14
def string_to_tensor(string: str, length: int) -> torch.Tensor:
    """Converts a string to a tensor of character indices.

    The string is NUL-padded / truncated to exactly *length* characters
    (pad_with_null inlined), and each character's code point is clamped
    to 127 so it fits the 128-entry vocabulary.
    """
    padded = (string + "\0" * max(0, length - len(string)))[:length]
    indices = [min(ord(ch), 127) for ch in padded]
    return torch.tensor(indices, dtype=torch.long)
20
-
21
def random_char() -> str:
    """Return one uniformly random character from string.printable."""
    index = randint(0, len(printable) - 1)
    return printable[index]
24
-
25
def random_str(length: int) -> str:
    """Return a random string of *length* printable characters
    (random_char inlined so each draw is uniform over string.printable)."""
    chars = [printable[randint(0, len(printable) - 1)] for _ in range(length)]
    return "".join(chars)
27
-
28
def mangle_string(source: str, d: int) -> str:
    """
    Apply roughly *d* random single-character edits to *source*.

    Each attempt picks a position and one of insert/delete/modify.
    A modify that draws the same character does not count as an edit,
    so the result can fall short of *d*; a hard cap of 3*d attempts
    prevents an infinite loop. (random_char is inlined here so the
    function is self-contained.)
    """
    if d <= 0:
        return source

    chars = list(source)
    applied = 0
    budget = d * 3  # maximum number of edit attempts

    while applied < d and budget > 0:
        budget -= 1

        if not chars:
            # Empty string: insertion at position 0 is the only option.
            pos, op = 0, "insert"
        else:
            pos = randint(0, len(chars) - 1)
            op = ("insert", "delete", "modify")[randint(0, 2)]

        if op == "insert":
            chars.insert(pos, printable[randint(0, len(printable) - 1)])
            applied += 1
        elif op == "modify" and chars:
            replacement = printable[randint(0, len(printable) - 1)]
            # Replacing a character with itself is not a real edit.
            if chars[pos] != replacement:
                chars[pos] = replacement
                applied += 1
        elif op == "delete" and chars:
            chars.pop(pos)
            applied += 1

    return "".join(chars)
65
-
66
def get_random_edit_distance(
    minimum: int, maximum: int, mean: float, dev: float
) -> int:
    """Draw an integer from N(mean, dev), clamped into [minimum, maximum]."""
    drawn = int(np.random.normal(loc=mean, scale=dev))
    # Clamp low first, then high — matches min(max(x, lo), hi).
    floored = max(drawn, minimum)
    return min(floored, maximum)
72
-
73
def get_homologous_pair(
    source: str, length: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build a (source, mangled, distance) training triple where *mangled* is a
    randomly edited copy of *source*.

    Fix: previously the maximum requested distance was min(length // 4, 10),
    which drops below 1 for length < 8; that inverts the clamp inside
    get_random_edit_distance and yields a 0-edit "homologous" pair of
    identical strings. The maximum is now floored at 1 so at least one
    edit is always requested. Behavior is unchanged for length >= 8.
    """
    max_distance = max(1, min(length // 4, 10))
    distance = get_random_edit_distance(1, max_distance, 3, 2)
    mangled = mangle_string(source, distance)

    # Label with the true Levenshtein distance, not the requested one —
    # mangle_string may apply fewer effective edits than asked for.
    actual_distance = ldistance(source, mangled)

    return (
        string_to_tensor(source, length),
        string_to_tensor(mangled, length),
        torch.tensor(float(actual_distance), dtype=torch.float),
    )
88
-
89
def get_non_homologous_pair(
    length: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build a (first, second, distance) triple from two independently random
    strings, labeled with their true Levenshtein distance.
    """
    first = random_str(length)
    second = random_str(length)

    # Redraw a handful of times so the pair is (almost surely) distinct.
    for _ in range(5):
        if first != second:
            break
        second = random_str(length)

    separation = ldistance(first, second)

    return (
        string_to_tensor(first, length),
        string_to_tensor(second, length),
        torch.tensor(float(separation), dtype=torch.float),
    )
109
-
110
def squared_euclidean_distance(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor:
    """Row-wise squared L2 distance between two (batch, dim) tensors."""
    diff = v1 - v2
    return (diff * diff).sum(dim=1)
112
-
113
def get_batch(
    size: int, batch_size: int
) -> list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Build a training batch of string pairs: half homologous (source +
    mangled copy), half unrelated random pairs, shuffled together.
    """
    half = int(batch_size / 2)

    homologous = [
        get_homologous_pair(random_str(size), size) for _ in range(half)
    ]
    unrelated = [get_non_homologous_pair(size) for _ in range(half)]

    batch = homologous + unrelated
    # Shuffle so the model cannot exploit pair ordering within the batch.
    np.random.shuffle(batch)
    return batch
131
-
132
def estimate_M(length: int, num_samples: int = 1000) -> float:
    """Estimates M, the average Levenshtein distance for non-homologous pairs."""
    total = sum(
        get_non_homologous_pair(length)[2].item() for _ in range(num_samples)
    )
    return total / num_samples
139
-
140
- def get_distances(
141
- batch: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
142
- model: torch.nn.Module,
143
- M: float | None = None,
144
- embedding_dim: int | None = None
145
- ):
146
- first: torch.Tensor = torch.stack([b[0] for b in batch])
147
- first = model(first)
148
-
149
- second: torch.Tensor = torch.stack([b[1] for b in batch])
150
- second = model(second)
151
-
152
- ds = torch.stack([b[2] for b in batch])
153
-
154
- d_hats = squared_euclidean_distance(first, second)
155
-
156
- if M is not None and embedding_dim is not None:
157
- # r(n) = sqrt(M / (2n)) from paper Eq. 6
158
- # We need r(n)^2 * d_hats, so (M / (2n)) * d_hats
159
- scaling_factor_squared = M / (2 * embedding_dim)
160
- d_hats = d_hats * scaling_factor_squared
161
-
162
- return (d_hats, ds)
163
-
164
def approximation_error(d_hat: torch.Tensor, d: torch.Tensor):
    """Mean absolute error between predicted and true distances."""
    return (d - d_hat).abs().mean()
166
-
167
def get_loss(d_hat: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
    """
    Wei et al. Poisson regression loss, averaged over the batch:
    PNLL(d_hat, d) = d_hat - d * ln(d_hat).

    Predictions are floored at a small epsilon before the logarithm so the
    loss stays finite when the model outputs (near-)zero distances.
    """
    eps = 1e-8
    safe = torch.clamp(d_hat, min=eps)
    return (safe - d * safe.log()).mean()
176
-
177
def validate_training_data(batch: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]) -> dict:
    """Validate and analyze training batch quality.

    Returns summary statistics over the batch's target distances,
    including counts of degenerate (zero) and unusually large (>15) labels.
    """
    distances = [pair[2].item() for pair in batch]

    return {
        'min_distance': min(distances),
        'max_distance': max(distances),
        'mean_distance': np.mean(distances),
        'std_distance': np.std(distances),
        'zero_distance_count': len([d for d in distances if d == 0]),
        'high_distance_count': len([d for d in distances if d > 15]),
    }
191
-
192
def run_experiment(
    embedding_dim: int,
    model: torch.nn.Module,
    learning_rate: float,
    num_steps: int,
    size: int,
    batch_size: int,
    use_gradient_clipping: bool = True,
    max_grad_norm: float = 1.0,
    distance_metric: str = "euclidean"
):
    """
    Runs a training experiment with the given parameters and returns
    (final_loss, final_approximation_error).

    Fixes vs. the previous version:
    - get_distances was called with distance_metric as a positional third
      argument *and* M= as a keyword; get_distances has no distance_metric
      parameter, so M received two values and every step raised TypeError.
    - optimizer.zero_grad() was never called, so gradients accumulated
      across steps.
    - use_gradient_clipping / max_grad_norm were accepted but ignored;
      clipping is now applied between backward() and step().

    distance_metric is retained for interface compatibility; only
    "euclidean" is implemented (get_distances always uses squared
    Euclidean distance).
    """
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)
    final_loss = 0.0
    final_approx_error = 0.0

    # Estimate M (mean non-homologous distance) once, up front; it scales
    # predictions inside get_distances for the whole run.
    M_estimate = estimate_M(size)
    print(f"Estimated M (average non-homologous distance): {M_estimate:.2f}")

    for step in range(num_steps):
        batch = get_batch(size, batch_size)

        d_hats, ds = get_distances(
            batch, model, M=M_estimate, embedding_dim=embedding_dim
        )
        loss = get_loss(d_hats, ds)

        if step % 10 == 0:
            print(
                f"step: {step}, loss: {loss.item()}, approx_error: {approximation_error(d_hats, ds).item()}"
            )

        optimizer.zero_grad()  # clear gradients from the previous step
        loss.backward()
        if use_gradient_clipping:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()

        final_loss = loss.item()
        final_approx_error = approximation_error(d_hats, ds).item()

    return final_loss, final_approx_error
234
-
235
if __name__ == "__main__":
    # Train an edit-distance embedding model and persist its weights.
    embedding_dim = 140

    model = EditDistanceModel(embedding_dim=embedding_dim)

    experiment_config = dict(
        embedding_dim=embedding_dim,
        model=model,
        learning_rate=0.000817,
        num_steps=1000,
        size=80,
        batch_size=32,
        use_gradient_clipping=True,
        max_grad_norm=2.463,
        distance_metric="euclidean",
    )
    final_loss, final_approx_error = run_experiment(**experiment_config)

    print(f"Final loss: {final_loss:.4f}")
    print(f"Final approximation error: {final_approx_error:.4f}")

    # Save the trained model
    model_path = "megashtein_trained_model.pth"
    torch.save(model.state_dict(), model_path)
    print(f"\n model saved to: {model_path}")
 
 
 
 
1
  import torch
 
 
 
2
 
3
class EditDistanceModel(torch.nn.Module):
    """
    Convolutional encoder that maps a fixed-length string (as a tensor of
    character indices in [0, vocab_size)) to a dense embedding, intended so
    that distance between embeddings approximates edit distance between
    strings.

    Generalization: the embedding output size — previously hard-coded to
    80 — is now the `output_dim` parameter; its default preserves the old
    behavior exactly.

    Args:
        vocab_size: number of distinct character indices (default 128).
        embedding_dim: per-character embedding width (default 16).
        input_length: expected input sequence length (default 80); used only
            to size the first fully-connected layer.
        output_dim: width of the output embedding (default 80).
    """

    def __init__(self, vocab_size=128, embedding_dim=16, input_length=80, output_dim=80):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)

        # Five conv/pool stages; each AvgPool1d(2) halves the sequence length.
        self.conv_layers = torch.nn.Sequential(
            torch.nn.Conv1d(embedding_dim, 64, 3, 1, 1),
            torch.nn.AvgPool1d(2),
            torch.nn.ReLU(),
            torch.nn.Conv1d(64, 64, 3, 1, 1),
            torch.nn.AvgPool1d(2),
            torch.nn.ReLU(),
            torch.nn.Conv1d(64, 64, 3, 1, 1),
            torch.nn.AvgPool1d(2),
            torch.nn.ReLU(),
            torch.nn.Conv1d(64, 64, 3, 1, 1),
            torch.nn.AvgPool1d(2),
            torch.nn.ReLU(),
            torch.nn.Conv1d(64, 64, 3, 1, 1),
            torch.nn.AvgPool1d(2),
            torch.nn.ReLU(),
        )

        self.flatten = torch.nn.Flatten()

        # Run a dummy forward pass to discover the flattened conv output
        # size instead of hand-computing it from the pooling arithmetic.
        with torch.no_grad():
            dummy_input = torch.zeros(1, input_length, dtype=torch.long)
            dummy_embedded = self.embedding(dummy_input)
            dummy_permuted = dummy_embedded.permute(0, 2, 1)
            dummy_conved = self.conv_layers(dummy_permuted)
            flattened_size = self.flatten(dummy_conved).shape[1]

        # NOTE: BatchNorm1d on the output requires batch size > 1 in
        # training mode.
        self.fc_layers = torch.nn.Sequential(
            torch.nn.Linear(flattened_size, 200),
            torch.nn.ReLU(),
            torch.nn.Linear(200, output_dim),
            torch.nn.BatchNorm1d(output_dim),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-initialize conv/linear weights; zero biases; reset batch norm."""
        for module in self.modules():
            if isinstance(module, (torch.nn.Linear, torch.nn.Conv1d)):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, torch.nn.BatchNorm1d):
                torch.nn.init.ones_(module.weight)
                torch.nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (batch, input_length) index tensor to (batch, output_dim) embeddings."""
        x = self.embedding(x)          # (batch, length, embedding_dim)
        x = x.permute(0, 2, 1)         # Conv1d expects channels-first
        x = self.conv_layers(x)
        x = self.flatten(x)
        x = self.fc_layers(x)
        return x