sawyerhpowell committed on
Commit
e3267cb
·
verified ·
1 Parent(s): c8ed28c

Upload models.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models.py +258 -0
models.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from random import randint
2
+ from string import printable
3
+ import numpy as np
4
+ import torch
5
+ from rapidfuzz.distance.Levenshtein import distance as ldistance
6
+ from torch.optim import AdamW
7
+ from models import EditDistanceModel
8
+
9
def pad_with_null(string: str, target_length: int):
    """Right-pad `string` with NUL characters to `target_length`, truncating if longer."""
    nul = "\0"
    deficit = max(0, target_length - len(string))
    return (string + nul * deficit)[:target_length]

def string_to_tensor(string: str, length: int) -> torch.Tensor:
    """Convert a string into a fixed-length 1-D tensor of character code points."""
    fixed = pad_with_null(string, length)
    # ord() yields the code point; clamp at 127 so indices stay inside an ASCII-sized vocab.
    return torch.tensor([min(ord(ch), 127) for ch in fixed], dtype=torch.long)
20
+
21
def random_char() -> str:
    """Uniformly sample a single character from string.printable."""
    return printable[randint(0, len(printable) - 1)]

def random_str(length: int) -> str:
    """Build a random string of `length` printable characters."""
    return "".join(random_char() for _ in range(length))
27
+
28
def mangle_string(source: str, d: int) -> str:
    """
    Apply roughly `d` random single-character edits (insert/delete/modify).

    The result's true Levenshtein distance from `source` may be smaller than
    `d`, since independent edits can cancel — callers should re-measure it.
    """
    if d <= 0:
        return source

    chars = list(source)
    applied = 0
    # Cap total iterations so repeated no-op modifies can't spin forever.
    budget = d * 3
    tries = 0

    while applied < d and tries < budget:
        tries += 1

        if not chars:
            # Nothing to delete/modify in an empty string; insertion is forced.
            idx, op = 0, "insert"
        else:
            idx = randint(0, len(chars) - 1)
            op = ["insert", "delete", "modify"][randint(0, 2)]

        if op == "insert":
            chars.insert(idx, random_char())
            applied += 1
        elif op == "modify" and chars:
            replacement = random_char()
            # A modify only counts as an edit when it actually changes the character.
            if chars[idx] != replacement:
                chars[idx] = replacement
                applied += 1
        elif op == "delete" and chars:
            chars.pop(idx)
            applied += 1

    return "".join(chars)
65
+
66
def get_random_edit_distance(
    minimum: int, maximum: int, mean: float, dev: float
) -> int:
    """Draw an integer from N(mean, dev), clamped into [minimum, maximum]."""
    draw = int(np.random.normal(loc=mean, scale=dev))
    return min(max(draw, minimum), maximum)
72
+
73
def get_homologous_pair(
    source: str, length: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Produce a (source, mutated-copy, true-distance) training triple.

    The mutated copy is `source` with a small random number of edits applied;
    the label is the measured Levenshtein distance, not the requested one,
    because mangle_string's edits can partially cancel each other.
    """
    # FIX: for length < 4, length // 4 == 0, which made maximum < minimum and
    # forced the sampled distance to 0; clamp so at least one edit is requested.
    distance = get_random_edit_distance(1, max(1, min(length // 4, 10)), 3, 2)
    mangled = mangle_string(source, distance)

    # Verify the actual distance and use it as the regression target.
    actual_distance = ldistance(source, mangled)

    return (
        string_to_tensor(source, length),
        string_to_tensor(mangled, length),
        torch.tensor(float(actual_distance), dtype=torch.float),
    )
88
+
89
def get_non_homologous_pair(
    length: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Produce a (random, unrelated-random, true-distance) training triple."""
    source = random_str(length)
    other = random_str(length)

    # Regenerate (up to 5 times) if the two strings happen to collide.
    retries = 0
    while source == other and retries < 5:
        other = random_str(length)
        retries += 1

    separation = ldistance(source, other)

    return (
        string_to_tensor(source, length),
        string_to_tensor(other, length),
        torch.tensor(float(separation), dtype=torch.float),
    )
109
+
110
def squared_euclidean_distance(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor:
    """Row-wise squared L2 distance between two (batch, dim) tensors."""
    diff = v1 - v2
    return (diff * diff).sum(dim=1)
112
+
113
def get_batch(
    size: int, batch_size: int
) -> list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Build a shuffled half-homologous / half-random batch of training triples."""
    half = batch_size // 2

    # Homologous pairs: each from a fresh random source string.
    similar = [get_homologous_pair(random_str(size), size) for _ in range(half)]
    unrelated = [get_non_homologous_pair(size) for _ in range(half)]

    # Shuffle so the model can't exploit the homologous-first ordering.
    combined = similar + unrelated
    np.random.shuffle(combined)

    return combined
131
+
132
def estimate_M(length: int, num_samples: int = 1000) -> float:
    """Estimate M, the mean Levenshtein distance between unrelated random strings."""
    total = 0.0
    for _ in range(num_samples):
        total += get_non_homologous_pair(length)[2].item()
    return total / num_samples
139
+
140
def get_distances(
    batch: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
    model: torch.nn.Module,
    M: float | None = None,
    embedding_dim: int | None = None
):
    """Embed both sides of every pair and return (predicted, true) distances.

    When both M and embedding_dim are given, the predictions are rescaled per
    the paper's Eq. 6; otherwise the raw squared embedding distances are used.
    """
    lefts, rights, targets = zip(*batch)

    left_emb = model(torch.stack(lefts))
    right_emb = model(torch.stack(rights))
    ds = torch.stack(targets)

    d_hats = squared_euclidean_distance(left_emb, right_emb)

    if M is not None and embedding_dim is not None:
        # r(n) = sqrt(M / (2n)) from paper Eq. 6; we scale the squared
        # distances, so the factor is r(n)^2 = M / (2n).
        d_hats = d_hats * (M / (2 * embedding_dim))

    return (d_hats, ds)
163
+
164
def approximation_error(d_hat: torch.Tensor, d: torch.Tensor):
    """Mean absolute error between predicted and true distances."""
    return (d - d_hat).abs().mean()
166
+
167
def get_loss(d_hat: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
    """
    Wei et al. Poisson regression loss function.

    PNLL(d_hat, d) = d_hat - d * ln(d_hat), averaged over the batch.
    Predictions are clamped away from zero so the log stays finite.
    """
    eps = 1e-8
    safe = d_hat.clamp(min=eps)
    return (safe - d * safe.log()).mean()
176
+
177
def validate_training_data(batch: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]) -> dict:
    """Summarize the distance labels of a batch for data-quality inspection."""
    labels = [triple[2].item() for triple in batch]

    return {
        'min_distance': min(labels),
        'max_distance': max(labels),
        'mean_distance': np.mean(labels),
        'std_distance': np.std(labels),
        'zero_distance_count': sum(d == 0 for d in labels),
        'high_distance_count': sum(d > 15 for d in labels),
    }
191
+
192
def run_experiment(
    embedding_dim: int,
    model: torch.nn.Module,
    learning_rate: float,
    num_steps: int,
    size: int,
    batch_size: int,
    use_gradient_clipping: bool = True,
    max_grad_norm: float = 1.0,
    distance_metric: str = "euclidean"
):
    """
    Train `model` for `num_steps` steps on freshly generated string pairs.

    Returns (final_loss, final_approx_error) from the last step.
    `distance_metric` is kept for interface compatibility; only the Euclidean
    path is implemented (see get_distances / squared_euclidean_distance).
    """
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)
    final_loss = 0.0
    final_approx_error = 0.0

    # Estimate M once at the beginning of the experiment.
    M_estimate = estimate_M(size)
    print(f"Estimated M (average non-homologous distance): {M_estimate:.2f}")

    for x in range(num_steps):
        batch = get_batch(size, batch_size)

        # BUG FIX: the previous call passed `distance_metric` positionally,
        # which bound to get_distances' `M` parameter and then collided with
        # the explicit M= keyword (TypeError: multiple values for 'M').
        d_hats, ds = get_distances(batch, model, M=M_estimate, embedding_dim=embedding_dim)
        loss = get_loss(d_hats, ds)

        if x % 10 == 0:
            print(
                f"step: {x}, loss: {loss.item()}, approx_error: {approximation_error(d_hats, ds).item()}"
            )

        # BUG FIX: gradients were never cleared, so they accumulated across
        # steps; zero them before each backward pass.
        optimizer.zero_grad()
        loss.backward()
        # BUG FIX: the clipping parameters were accepted but never applied.
        if use_gradient_clipping:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()

        final_loss = loss.item()
        final_approx_error = approximation_error(d_hats, ds).item()

    return final_loss, final_approx_error
234
+
235
if __name__ == "__main__":
    # Dimensionality of the learned string embedding.
    embedding_dim = 140

    # NOTE(review): EditDistanceModel comes from `models`; its architecture is
    # not visible here — only that it accepts embedding_dim as a keyword.
    model = EditDistanceModel(embedding_dim=embedding_dim)

    # The oddly specific learning rate / grad-norm values (0.000817, 2.463)
    # presumably come from a hyperparameter sweep — TODO confirm.
    final_loss, final_approx_error = run_experiment(
        embedding_dim=embedding_dim,
        model=model,
        learning_rate=0.000817,
        num_steps=1000,
        size=80,            # length of each generated string
        batch_size=32,
        use_gradient_clipping=True,
        max_grad_norm=2.463,
        distance_metric="euclidean",
    )

    print(f"Final loss: {final_loss:.4f}")
    print(f"Final approximation error: {final_approx_error:.4f}")

    # Save the trained model
    model_path = "megashtein_trained_model.pth"
    torch.save(model.state_dict(), model_path)
    print(f"\n model saved to: {model_path}")