adriennehoarfrost committed on
Commit 2eef7fa · verified · 1 Parent(s): a3061d8

Upload folder using huggingface_hub

Files changed (5)
  1. README.md +88 -0
  2. config.json +29 -0
  3. lookingglass.py +780 -0
  4. lookingglass_classifier.py +258 -0
  5. pytorch_model.bin +3 -0
README.md ADDED
@@ -0,0 +1,88 @@
+ ---
+ language:
+ - en
+ tags:
+ - biology
+ - dna
+ - genomics
+ - metagenomics
+ - classifier
+ - awd-lstm
+ - transfer-learning
+ license: mit
+ pipeline_tag: text-classification
+ library_name: pytorch
+ ---
+
+ # LookingGlass Reading Frame Classifier
+
+ Identifies the correct reading-frame start position (1, 2, 3, -1, -2, or -3) for DNA reads. Note: the classifier is currently intended only for prokaryotic sequences with a low proportion of noncoding DNA.
+
+ This is a **pure PyTorch implementation** fine-tuned from the LookingGlass base model.
+
+ ## Links
+
+ - **Paper**: [Deep learning of a bacterial and archaeal universal language of life enables transfer learning and illuminates microbial dark matter](https://doi.org/10.1038/s41467-022-30070-8) (Nature Communications, 2022)
+ - **GitHub**: [ahoarfrost/LookingGlass](https://github.com/ahoarfrost/LookingGlass)
+ - **Base Model**: [HoarfrostLab/lookingglass-v1](https://huggingface.co/HoarfrostLab/lookingglass-v1)
+
+ ## Citation
+
+ ```bibtex
+ @article{hoarfrost2022deep,
+     title={Deep learning of a bacterial and archaeal universal language of life
+            enables transfer learning and illuminates microbial dark matter},
+     author={Hoarfrost, Adrienne and Aptekmann, Ariel and Farfanuk, Gaetan and Bromberg, Yana},
+     journal={Nature Communications},
+     volume={13},
+     number={1},
+     pages={2606},
+     year={2022},
+     publisher={Nature Publishing Group}
+ }
+ ```
+
+ ## Model
+
+ | | |
+ |---|---|
+ | Architecture | LookingGlass encoder + classification head |
+ | Encoder | AWD-LSTM (3-layer, unidirectional) |
+ | Classes | 6 (1, 2, 3, -1, -2, -3) |
+ | Parameters | ~17M |
+
+ ## Installation
+
+ ```bash
+ pip install torch
+ git clone https://huggingface.co/HoarfrostLab/LGv1_ReadingFrameClassifier
+ cd LGv1_ReadingFrameClassifier
+ ```
+
+ ## Usage
+
+ ```python
+ from lookingglass_classifier import LookingGlassClassifier, LookingGlassTokenizer
+
+ model = LookingGlassClassifier.from_pretrained('.')
+ tokenizer = LookingGlassTokenizer()
+ model.eval()
+
+ inputs = tokenizer(["GATTACA", "ATCGATCGATCG"], return_tensors=True)
+
+ # Get predictions
+ predictions = model.predict(inputs['input_ids'])
+ print(predictions)  # tensor([class_idx, class_idx])
+
+ # Get probabilities
+ probs = model.predict_proba(inputs['input_ids'])
+ print(probs.shape)  # torch.Size([2, 6])
+
+ # Get raw logits
+ logits = model(inputs['input_ids'])
+ print(logits.shape)  # torch.Size([2, 6])
+ ```
+
+ ## License
+
+ MIT License
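The `predict` call in the usage snippet returns class indices, not frame labels. A minimal sketch of mapping those indices back to the reading-frame labels, assuming the index order matches the `class_names` list shipped in this repo's `config.json` (`frame_label` is a hypothetical helper, not part of the package):

```python
# class_names as listed in config.json; assumption: logit column i corresponds
# to class_names[i]
CLASS_NAMES = ["1", "2", "3", "-1", "-2", "-3"]

def frame_label(class_idx: int) -> str:
    """Map a predicted class index to its reading-frame label."""
    return CLASS_NAMES[class_idx]

print(frame_label(0))  # -> "1"
print(frame_label(4))  # -> "-2"
```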
config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "vocab_size": 8,
+   "hidden_size": 104,
+   "intermediate_size": 1152,
+   "num_hidden_layers": 3,
+   "pad_token_id": 1,
+   "bos_token_id": 2,
+   "eos_token_id": 3,
+   "bidirectional": false,
+   "output_dropout": 0.1,
+   "hidden_dropout": 0.15,
+   "input_dropout": 0.25,
+   "embed_dropout": 0.02,
+   "weight_dropout": 0.2,
+   "tie_weights": true,
+   "output_bias": true,
+   "model_type": "lookingglass",
+   "num_classes": 6,
+   "classifier_hidden": 50,
+   "classifier_dropout": 0.0,
+   "class_names": [
+     "1",
+     "2",
+     "3",
+     "-1",
+     "-2",
+     "-3"
+   ]
+ }
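Note that this config contains classifier-only keys (`num_classes`, `class_names`, ...) on top of the base-model fields. The loading code tolerates this by filtering unknown keys against the dataclass fields before constructing the config. A minimal self-contained sketch of that pattern, using a hypothetical `MiniConfig` with just three of the fields:

```python
import json
from dataclasses import dataclass, fields

# Hypothetical stand-in for LookingGlassConfig with a subset of its fields.
@dataclass
class MiniConfig:
    vocab_size: int = 8
    hidden_size: int = 104
    num_hidden_layers: int = 3

raw = json.loads(
    '{"vocab_size": 8, "hidden_size": 104, "num_hidden_layers": 3, "num_classes": 6}'
)
# Drop keys the dataclass does not declare (e.g. "num_classes") instead of
# letting the constructor raise TypeError.
valid = {f.name for f in fields(MiniConfig)}
cfg = MiniConfig(**{k: v for k, v in raw.items() if k in valid})
print(cfg.hidden_size)  # -> 104
```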
lookingglass.py ADDED
@@ -0,0 +1,780 @@
1
+ """
2
+ LookingGlass - A DNA Language Model
3
+
4
+ Pure PyTorch implementation of LookingGlass, a pretrained language model for DNA sequences.
5
+ Based on AWD-LSTM architecture, originally trained with fastai v1.
6
+
7
+ Paper: Hoarfrost et al., "Deep learning of a bacterial and archaeal universal language
8
+ of life enables transfer learning and illuminates microbial dark matter",
9
+ Nature Communications, 2022.
10
+
11
+ Usage:
12
+ from lookingglass import LookingGlass, LookingGlassTokenizer
13
+
14
+ # Load from HuggingFace Hub
15
+ model = LookingGlass.from_pretrained('HoarfrostLab/lookingglass-v1')
16
+ tokenizer = LookingGlassTokenizer()
17
+
18
+ # Or load from local path
19
+ model = LookingGlass.from_pretrained('./lookingglass-v1')
20
+
21
+ inputs = tokenizer(["GATTACA", "ATCGATCG"], return_tensors=True)
22
+ embeddings = model.get_embeddings(inputs['input_ids']) # (batch, 104)
23
+ """
24
+
25
+ import json
26
+ import os
27
+ import warnings
28
+ from dataclasses import dataclass, asdict
29
+ from typing import Optional, Tuple, List, Dict, Union
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+
35
+ try:
36
+ from huggingface_hub import hf_hub_download
37
+ HF_HUB_AVAILABLE = True
38
+ except ImportError:
39
+ HF_HUB_AVAILABLE = False
40
+
41
+
42
+ __version__ = "1.1.0"
43
+
44
+
45
+ def _is_hf_hub_id(path: str) -> bool:
46
+ """Check if path looks like a HuggingFace Hub model ID (e.g., 'user/model')."""
47
+ if os.path.exists(path):
48
+ return False
49
+ return '/' in path and not path.startswith(('.', '/'))
50
+
51
+
52
+ def _download_from_hub(repo_id: str, filename: str) -> str:
53
+ """Download a file from HuggingFace Hub and return the local path."""
54
+ if not HF_HUB_AVAILABLE:
55
+ raise ImportError(
56
+ "huggingface_hub is required to load models from the Hub. "
57
+ "Install it with: pip install huggingface_hub"
58
+ )
59
+ return hf_hub_download(repo_id=repo_id, filename=filename)
60
+ __all__ = [
61
+ "LookingGlassConfig",
62
+ "LookingGlass",
63
+ "LookingGlassLM",
64
+ "LookingGlassTokenizer",
65
+ ]
66
+
67
+
68
+ # =============================================================================
69
+ # Configuration
70
+ # =============================================================================
71
+
72
+ @dataclass
73
+ class LookingGlassConfig:
74
+ """
75
+ Configuration for LookingGlass model.
76
+
77
+ Default values match the original pretrained LookingGlass model.
78
+ """
79
+ vocab_size: int = 8
80
+ hidden_size: int = 104 # embedding/output size
81
+ intermediate_size: int = 1152 # LSTM hidden size
82
+ num_hidden_layers: int = 3
83
+ pad_token_id: int = 1
84
+ bos_token_id: int = 2
85
+ eos_token_id: int = 3
86
+ bidirectional: bool = False # original LG is unidirectional
87
+ output_dropout: float = 0.1
88
+ hidden_dropout: float = 0.15
89
+ input_dropout: float = 0.25
90
+ embed_dropout: float = 0.02
91
+ weight_dropout: float = 0.2
92
+ tie_weights: bool = True
93
+ output_bias: bool = True
94
+ model_type: str = "lookingglass"
95
+
96
+ def to_dict(self) -> Dict:
97
+ return asdict(self)
98
+
99
+ def save_pretrained(self, save_directory: str):
100
+ os.makedirs(save_directory, exist_ok=True)
101
+ with open(os.path.join(save_directory, "config.json"), 'w') as f:
102
+ json.dump(self.to_dict(), f, indent=2)
103
+
104
+ @classmethod
105
+ def from_pretrained(cls, pretrained_path: str) -> "LookingGlassConfig":
106
+ if _is_hf_hub_id(pretrained_path):
107
+ try:
108
+ config_path = _download_from_hub(pretrained_path, "config.json")
109
+ except Exception:
110
+ return cls()
111
+ elif os.path.isdir(pretrained_path):
112
+ config_path = os.path.join(pretrained_path, "config.json")
113
+ else:
114
+ config_path = pretrained_path
115
+
116
+ if os.path.exists(config_path):
117
+ with open(config_path, 'r') as f:
118
+ config_dict = json.load(f)
119
+ valid_fields = {f.name for f in cls.__dataclass_fields__.values()}
120
+ return cls(**{k: v for k, v in config_dict.items() if k in valid_fields})
121
+ return cls()
122
+
123
+
124
+ # =============================================================================
125
+ # Tokenizer
126
+ # =============================================================================
127
+
128
+ VOCAB = ['xxunk', 'xxpad', 'xxbos', 'xxeos', 'G', 'A', 'C', 'T']
129
+ VOCAB_TO_ID = {tok: i for i, tok in enumerate(VOCAB)}
130
+ ID_TO_VOCAB = {i: tok for i, tok in enumerate(VOCAB)}
131
+
132
+
133
+ class LookingGlassTokenizer:
134
+ """
135
+ Tokenizer for DNA sequences.
136
+
137
+ Each nucleotide (G, A, C, T) is a single token. By default, adds BOS token
138
+ at the start of each sequence (matching original LookingGlass training).
139
+
140
+ Special tokens:
141
+ - xxunk (0): Unknown
142
+ - xxpad (1): Padding
143
+ - xxbos (2): Beginning of sequence
144
+ - xxeos (3): End of sequence
145
+ """
146
+
147
+ vocab = VOCAB
148
+ vocab_to_id = VOCAB_TO_ID
149
+ id_to_vocab = ID_TO_VOCAB
150
+
151
+ def __init__(
152
+ self,
153
+ add_bos_token: bool = True, # original LG uses BOS
154
+ add_eos_token: bool = False, # original LG does not use EOS
155
+ padding_side: str = "right",
156
+ ):
157
+ self.add_bos_token = add_bos_token
158
+ self.add_eos_token = add_eos_token
159
+ self.padding_side = padding_side
160
+
161
+ self.unk_token_id = 0
162
+ self.pad_token_id = 1
163
+ self.bos_token_id = 2
164
+ self.eos_token_id = 3
165
+
166
+ @property
167
+ def vocab_size(self) -> int:
168
+ return len(self.vocab)
169
+
170
+ def encode(self, sequence: str, add_special_tokens: bool = True) -> List[int]:
171
+ """Encode a DNA sequence to token IDs."""
172
+ tokens = []
173
+
174
+ if add_special_tokens and self.add_bos_token:
175
+ tokens.append(self.bos_token_id)
176
+
177
+ for char in sequence.upper():
178
+ if char in self.vocab_to_id:
179
+ tokens.append(self.vocab_to_id[char])
180
+ elif char.strip():
181
+ tokens.append(self.unk_token_id)
182
+
183
+ if add_special_tokens and self.add_eos_token:
184
+ tokens.append(self.eos_token_id)
185
+
186
+ return tokens
187
+
188
+ def decode(self, token_ids: Union[List[int], torch.Tensor], skip_special_tokens: bool = True) -> str:
189
+ """Decode token IDs back to DNA sequence."""
190
+ if isinstance(token_ids, torch.Tensor):
191
+ token_ids = token_ids.tolist()
192
+
193
+ special_ids = {0, 1, 2, 3}
194
+ tokens = []
195
+ for tid in token_ids:
196
+ if skip_special_tokens and tid in special_ids:
197
+ continue
198
+ tokens.append(self.id_to_vocab.get(tid, 'xxunk'))
199
+ return ''.join(tokens)
200
+
201
+ def __call__(
202
+ self,
203
+ sequences: Union[str, List[str]],
204
+ padding: Union[bool, str] = False,
205
+ max_length: Optional[int] = None,
206
+ truncation: bool = False,
207
+ return_tensors: Union[bool, str] = False,
208
+ return_attention_mask: bool = True,
209
+ ) -> Dict[str, torch.Tensor]:
210
+ """Tokenize DNA sequence(s)."""
211
+ if isinstance(sequences, str):
212
+ sequences = [sequences]
213
+ single = True
214
+ else:
215
+ single = False
216
+
217
+ encoded = [self.encode(seq) for seq in sequences]
218
+
219
+ if truncation and max_length:
220
+ encoded = [e[:max_length] for e in encoded]
221
+
222
+ # Padding
223
+ if padding or len(encoded) > 1:
224
+ if padding == 'max_length' and max_length:
225
+ pad_len = max_length
226
+ else:
227
+ pad_len = max(len(e) for e in encoded)
228
+
229
+ padded = []
230
+ masks = []
231
+ for e in encoded:
232
+ pad_amount = pad_len - len(e)
233
+ mask = [1] * len(e) + [0] * pad_amount
234
+ if self.padding_side == 'right':
235
+ e = e + [self.pad_token_id] * pad_amount
236
+ else:
237
+ e = [self.pad_token_id] * pad_amount + e
238
+ mask = [0] * pad_amount + [1] * len(e)
239
+ padded.append(e)
240
+ masks.append(mask)
241
+ encoded = padded
242
+ else:
243
+ masks = [[1] * len(e) for e in encoded]
244
+
245
+ result = {}
246
+ if return_tensors in ('pt', True):
247
+ result['input_ids'] = torch.tensor(encoded, dtype=torch.long)
248
+ if return_attention_mask:
249
+ result['attention_mask'] = torch.tensor(masks, dtype=torch.long)
250
+ else:
251
+ result['input_ids'] = encoded[0] if single else encoded
252
+ if return_attention_mask:
253
+ result['attention_mask'] = masks[0] if single else masks
254
+
255
+ return result
256
+
257
+ def save_pretrained(self, save_directory: str):
258
+ os.makedirs(save_directory, exist_ok=True)
259
+ with open(os.path.join(save_directory, "vocab.json"), 'w') as f:
260
+ json.dump(self.vocab_to_id, f, indent=2)
261
+ with open(os.path.join(save_directory, "tokenizer_config.json"), 'w') as f:
262
+ json.dump({
263
+ "add_bos_token": self.add_bos_token,
264
+ "add_eos_token": self.add_eos_token,
265
+ "padding_side": self.padding_side,
266
+ }, f, indent=2)
267
+
268
+ @classmethod
269
+ def from_pretrained(cls, pretrained_path: str) -> "LookingGlassTokenizer":
270
+ kwargs = {}
271
+ if _is_hf_hub_id(pretrained_path):
272
+ try:
273
+ config_path = _download_from_hub(pretrained_path, "tokenizer_config.json")
274
+ with open(config_path, 'r') as f:
275
+ kwargs = json.load(f)
276
+ except Exception:
277
+ pass
278
+ else:
279
+ config_path = os.path.join(pretrained_path, "tokenizer_config.json")
280
+ if os.path.exists(config_path):
281
+ with open(config_path, 'r') as f:
282
+ kwargs = json.load(f)
283
+ return cls(**kwargs)
284
+
285
+
286
+ # =============================================================================
287
+ # Model Components
288
+ # =============================================================================
289
+
290
+ def _dropout_mask(x: torch.Tensor, size: Tuple[int, ...], p: float) -> torch.Tensor:
291
+ """Create dropout mask with inverted scaling."""
292
+ return x.new_empty(*size).bernoulli_(1 - p).div_(1 - p)
293
+
294
+
295
+ class _RNNDropout(nn.Module):
296
+ """Dropout consistent across sequence dimension."""
297
+
298
+ def __init__(self, p: float = 0.5):
299
+ super().__init__()
300
+ self.p = p
301
+
302
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
303
+ if not self.training or self.p == 0.:
304
+ return x
305
+ mask = _dropout_mask(x.data, (x.size(0), 1, x.size(2)), self.p)
306
+ return x * mask
307
+
308
+
309
+ class _EmbeddingDropout(nn.Module):
310
+ """Dropout applied to entire embedding rows."""
311
+
312
+ def __init__(self, embedding: nn.Embedding, p: float):
313
+ super().__init__()
314
+ self.embedding = embedding
315
+ self.p = p
316
+
317
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
318
+ if self.training and self.p != 0:
319
+ mask = _dropout_mask(self.embedding.weight.data,
320
+ (self.embedding.weight.size(0), 1), self.p)
321
+ masked_weight = self.embedding.weight * mask
322
+ else:
323
+ masked_weight = self.embedding.weight
324
+
325
+ padding_idx = self.embedding.padding_idx if self.embedding.padding_idx is not None else -1
326
+ return F.embedding(x, masked_weight, padding_idx,
327
+ self.embedding.max_norm, self.embedding.norm_type,
328
+ self.embedding.scale_grad_by_freq, self.embedding.sparse)
329
+
330
+
331
+ class _WeightDropout(nn.Module):
332
+ """DropConnect applied to RNN hidden-to-hidden weights."""
333
+
334
+ def __init__(self, module: nn.Module, p: float, layer_names='weight_hh_l0'):
335
+ super().__init__()
336
+ self.module = module
337
+ self.p = p
338
+ self.layer_names = [layer_names] if isinstance(layer_names, str) else layer_names
339
+
340
+ for layer in self.layer_names:
341
+ w = getattr(self.module, layer)
342
+ delattr(self.module, layer)
343
+ self.register_parameter(f'{layer}_raw', nn.Parameter(w.data))
344
+ setattr(self.module, layer, w.clone())
345
+
346
+ if isinstance(self.module, nn.RNNBase):
347
+ self.module.flatten_parameters = lambda: None
348
+
349
+ def _set_weights(self):
350
+ for layer in self.layer_names:
351
+ raw_w = getattr(self, f'{layer}_raw')
352
+ w = F.dropout(raw_w, p=self.p, training=self.training) if self.training else raw_w.clone()
353
+ setattr(self.module, layer, w)
354
+
355
+ def forward(self, *args):
356
+ self._set_weights()
357
+ with warnings.catch_warnings():
358
+ warnings.simplefilter("ignore", category=UserWarning)
359
+ return self.module(*args)
360
+
361
+
362
+ class _AWDLSTMEncoder(nn.Module):
363
+ """AWD-LSTM encoder backbone."""
364
+
365
+ _init_range = 0.1
366
+
367
+ def __init__(self, config: LookingGlassConfig):
368
+ super().__init__()
369
+ self.config = config
370
+ self.hidden_size = config.hidden_size
371
+ self.intermediate_size = config.intermediate_size
372
+ self.num_layers = config.num_hidden_layers
373
+ self.num_directions = 2 if config.bidirectional else 1
374
+ self._batch_size = 1
375
+
376
+ # Embedding
377
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
378
+ padding_idx=config.pad_token_id)
379
+ self.embed_tokens.weight.data.uniform_(-self._init_range, self._init_range)
380
+ self.embed_dropout = _EmbeddingDropout(self.embed_tokens, config.embed_dropout)
381
+
382
+ # LSTM layers
383
+ self.layers = nn.ModuleList()
384
+ for i in range(config.num_hidden_layers):
385
+ input_size = config.hidden_size if i == 0 else config.intermediate_size
386
+ output_size = (config.intermediate_size if i != config.num_hidden_layers - 1
387
+ else config.hidden_size) // self.num_directions
388
+ lstm = nn.LSTM(input_size, output_size, num_layers=1,
389
+ batch_first=True, bidirectional=config.bidirectional)
390
+ self.layers.append(_WeightDropout(lstm, config.weight_dropout))
391
+
392
+ # Dropout
393
+ self.input_dropout = _RNNDropout(config.input_dropout)
394
+ self.hidden_dropout = nn.ModuleList([
395
+ _RNNDropout(config.hidden_dropout) for _ in range(config.num_hidden_layers)
396
+ ])
397
+
398
+ self._hidden_state = None
399
+ self.reset()
400
+
401
+ def reset(self):
402
+ """Reset LSTM hidden states."""
403
+ self._hidden_state = [self._init_hidden(i) for i in range(self.num_layers)]
404
+
405
+ def _init_hidden(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
406
+ nh = (self.intermediate_size if layer_idx != self.num_layers - 1
407
+ else self.hidden_size) // self.num_directions
408
+ weight = next(self.parameters())
409
+ return (weight.new_zeros(self.num_directions, self._batch_size, nh),
410
+ weight.new_zeros(self.num_directions, self._batch_size, nh))
411
+
412
+ def _resize_hidden(self, batch_size: int):
413
+ new_hidden = []
414
+ for i in range(self.num_layers):
415
+ nh = (self.intermediate_size if i != self.num_layers - 1
416
+ else self.hidden_size) // self.num_directions
417
+ h, c = self._hidden_state[i]
418
+
419
+ if self._batch_size < batch_size:
420
+ h = torch.cat([h, h.new_zeros(self.num_directions, batch_size - self._batch_size, nh)], dim=1)
421
+ c = torch.cat([c, c.new_zeros(self.num_directions, batch_size - self._batch_size, nh)], dim=1)
422
+ elif self._batch_size > batch_size:
423
+ h = h[:, :batch_size].contiguous()
424
+ c = c[:, :batch_size].contiguous()
425
+ new_hidden.append((h, c))
426
+
427
+ self._hidden_state = new_hidden
428
+ self._batch_size = batch_size
429
+
430
+ def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
431
+ """Returns hidden states for all positions: (batch, seq_len, hidden_size)"""
432
+ batch_size, seq_len = input_ids.shape
433
+
434
+ if batch_size != self._batch_size:
435
+ self._resize_hidden(batch_size)
436
+
437
+ hidden = self.input_dropout(self.embed_dropout(input_ids))
438
+
439
+ new_hidden = []
440
+ for i, (layer, hdp) in enumerate(zip(self.layers, self.hidden_dropout)):
441
+ hidden, h = layer(hidden, self._hidden_state[i])
442
+ new_hidden.append(h)
443
+ if i != self.num_layers - 1:
444
+ hidden = hdp(hidden)
445
+
446
+ self._hidden_state = [(h.detach(), c.detach()) for h, c in new_hidden]
447
+ return hidden
448
+
449
+
450
+ class _LMHead(nn.Module):
451
+ """Language modeling head."""
452
+
453
+ _init_range = 0.1
454
+
455
+ def __init__(self, config: LookingGlassConfig, embed_tokens: Optional[nn.Embedding] = None):
456
+ super().__init__()
457
+ self.output_dropout = _RNNDropout(config.output_dropout)
458
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.output_bias)
459
+ self.decoder.weight.data.uniform_(-self._init_range, self._init_range)
460
+
461
+ if config.output_bias:
462
+ self.decoder.bias.data.zero_()
463
+
464
+ if embed_tokens is not None and config.tie_weights:
465
+ self.decoder.weight = embed_tokens.weight
466
+
467
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
468
+ return self.decoder(self.output_dropout(hidden_states))
469
+
470
+
471
+ # =============================================================================
472
+ # Models
473
+ # =============================================================================
474
+
475
+ class LookingGlass(nn.Module):
476
+ """
477
+ LookingGlass encoder model.
478
+
479
+ Outputs sequence embeddings for downstream tasks (classification, clustering, etc.).
480
+ Uses last-token embedding by default, matching original LookingGlass.
481
+
482
+ Example:
483
+ >>> model = LookingGlass.from_pretrained('lookingglass-v1')
484
+ >>> tokenizer = LookingGlassTokenizer()
485
+ >>> inputs = tokenizer("GATTACA", return_tensors=True)
486
+ >>> embeddings = model.get_embeddings(inputs['input_ids']) # (1, 104)
487
+ """
488
+
489
+ config_class = LookingGlassConfig
490
+
491
+ def __init__(self, config: Optional[LookingGlassConfig] = None):
492
+ super().__init__()
493
+ self.config = config or LookingGlassConfig()
494
+ self.encoder = _AWDLSTMEncoder(self.config)
495
+
496
+ def reset(self):
497
+ """Reset hidden states."""
498
+ self.encoder.reset()
499
+
500
+ def forward(self, input_ids: torch.LongTensor, **kwargs) -> torch.Tensor:
501
+ """
502
+ Forward pass. Returns last-token embeddings.
503
+
504
+ Args:
505
+ input_ids: Token indices (batch, seq_len)
506
+
507
+ Returns:
508
+ Embeddings (batch, hidden_size)
509
+ """
510
+ return self.get_embeddings(input_ids)
511
+
512
+ def get_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
513
+ """
514
+ Get sequence embeddings using last-token pooling (original LG method).
515
+
516
+ Resets hidden state before encoding for deterministic results.
517
+
518
+ Args:
519
+ input_ids: Token indices (batch, seq_len)
520
+
521
+ Returns:
522
+ Embeddings (batch, hidden_size)
523
+ """
524
+ self.encoder.reset()
525
+ hidden = self.encoder(input_ids) # (batch, seq_len, hidden_size)
526
+ return hidden[:, -1] # last token
527
+
528
+ def get_hidden_states(self, input_ids: torch.LongTensor) -> torch.Tensor:
529
+ """
530
+ Get hidden states for all positions.
531
+
532
+ Resets hidden state before encoding for deterministic results.
533
+
534
+ Args:
535
+ input_ids: Token indices (batch, seq_len)
536
+
537
+ Returns:
538
+ Hidden states (batch, seq_len, hidden_size)
539
+ """
540
+ self.encoder.reset()
541
+ return self.encoder(input_ids)
542
+
543
+ def save_pretrained(self, save_directory: str):
544
+ os.makedirs(save_directory, exist_ok=True)
545
+ self.config.save_pretrained(save_directory)
546
+ torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
547
+
548
+ @classmethod
549
+ def from_pretrained(cls, pretrained_path: str, config: Optional[LookingGlassConfig] = None) -> "LookingGlass":
550
+ config = config or LookingGlassConfig.from_pretrained(pretrained_path)
551
+ model = cls(config)
552
+
553
+ if _is_hf_hub_id(pretrained_path):
554
+ model_path = _download_from_hub(pretrained_path, "pytorch_model.bin")
555
+ else:
556
+ model_path = os.path.join(pretrained_path, "pytorch_model.bin")
557
+
558
+ if os.path.exists(model_path):
559
+ state_dict = torch.load(model_path, map_location='cpu')
560
+ # Only load encoder weights
561
+ encoder_state_dict = {k: v for k, v in state_dict.items()
562
+ if not k.startswith('lm_head.')}
563
+ model.load_state_dict(encoder_state_dict, strict=False)
564
+
565
+ return model
566
+
567
+
568
+ class LookingGlassLM(nn.Module):
569
+ """
570
+ LookingGlass with language modeling head.
571
+
572
+ Full model for next-token prediction. Can also extract embeddings.
573
+
574
+ Example:
575
+ >>> model = LookingGlassLM.from_pretrained('lookingglass-v1')
576
+ >>> tokenizer = LookingGlassTokenizer()
577
+ >>> inputs = tokenizer("GATTACA", return_tensors=True)
578
+ >>> logits = model(inputs['input_ids']) # (1, 8, 8)
579
+ >>> embeddings = model.get_embeddings(inputs['input_ids']) # (1, 104)
580
+ """
581
+
582
+ config_class = LookingGlassConfig
583
+
584
+ def __init__(self, config: Optional[LookingGlassConfig] = None):
585
+ super().__init__()
586
+ self.config = config or LookingGlassConfig()
587
+ self.encoder = _AWDLSTMEncoder(self.config)
588
+ self.lm_head = _LMHead(
589
+ self.config,
590
+ embed_tokens=self.encoder.embed_tokens if self.config.tie_weights else None
591
+ )
592
+
593
+ def reset(self):
594
+ """Reset hidden states."""
595
+ self.encoder.reset()
596
+
597
+ def forward(self, input_ids: torch.LongTensor, **kwargs) -> torch.Tensor:
598
+ """
599
+ Forward pass. Returns logits for next-token prediction.
600
+
601
+ Args:
602
+ input_ids: Token indices (batch, seq_len)
603
+
604
+ Returns:
605
+ Logits (batch, seq_len, vocab_size)
606
+ """
607
+ hidden = self.encoder(input_ids)
608
+ return self.lm_head(hidden)
609
+
610
+ def get_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
611
+ """
612
+ Get sequence embeddings using last-token pooling.
613
+
614
+ Resets hidden state before encoding for deterministic results.
615
+
616
+ Args:
617
+ input_ids: Token indices (batch, seq_len)
618
+
619
+ Returns:
620
+ Embeddings (batch, hidden_size)
621
+ """
622
+ self.encoder.reset()
623
+ hidden = self.encoder(input_ids)
624
+ return hidden[:, -1]
625
+
626
+ def get_hidden_states(self, input_ids: torch.LongTensor) -> torch.Tensor:
627
+ """
628
+ Get hidden states for all positions.
629
+
630
+ Resets hidden state before encoding for deterministic results.
631
+
632
+ Args:
633
+ input_ids: Token indices (batch, seq_len)
634
+
635
+ Returns:
636
+ Hidden states (batch, seq_len, hidden_size)
637
+ """
638
+ self.encoder.reset()
639
+ return self.encoder(input_ids)
640
+
641
+ def save_pretrained(self, save_directory: str):
642
+ os.makedirs(save_directory, exist_ok=True)
643
+ self.config.save_pretrained(save_directory)
644
+ torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
645
+
646
+ @classmethod
647
+ def from_pretrained(cls, pretrained_path: str, config: Optional[LookingGlassConfig] = None) -> "LookingGlassLM":
648
+ config = config or LookingGlassConfig.from_pretrained(pretrained_path)
649
+ model = cls(config)
650
+
651
+ if _is_hf_hub_id(pretrained_path):
652
+ model_path = _download_from_hub(pretrained_path, "pytorch_model.bin")
653
+ else:
654
+ model_path = os.path.join(pretrained_path, "pytorch_model.bin")
655
+
656
+ if os.path.exists(model_path):
657
+ state_dict = torch.load(model_path, map_location='cpu')
658
+ model.load_state_dict(state_dict, strict=False)
659
+
660
+ return model
661
+
662
+
663
+ # =============================================================================
664
+ # Weight Loading
665
+ # =============================================================================
666
+
667
+ def load_original_weights(model: Union[LookingGlass, LookingGlassLM], weights_path: str) -> None:
668
+ """
669
+ Load weights from original fastai-trained LookingGlass checkpoint.
670
+
671
+ Args:
672
+ model: Model to load weights into
673
+ weights_path: Path to LookingGlass.pth or LookingGlass_enc.pth
674
+ """
675
+ checkpoint = torch.load(weights_path, map_location='cpu')
676
+
677
+ if 'model' in checkpoint:
678
+ state_dict = checkpoint['model']
679
+ else:
680
+ state_dict = checkpoint
681
+
682
+ is_lm_model = isinstance(model, LookingGlassLM)
683
+
684
+ new_state_dict = {}
685
+ for k, v in state_dict.items():
686
+ if '.module.weight_hh_l0' in k:
687
+ continue
688
+
689
+ if k.startswith('0.'):
690
+ new_k = k[2:]
691
+ new_k = new_k.replace('encoder.', 'embed_tokens.')
692
+ new_k = new_k.replace('encoder_dp.emb.', 'embed_tokens.')
693
+ new_k = new_k.replace('rnns.', 'layers.')
694
+ new_k = new_k.replace('hidden_dps.', 'hidden_dropout.')
695
+ new_k = new_k.replace('input_dp.', 'input_dropout.')
696
+ new_state_dict['encoder.' + new_k] = v
697
+
698
+ elif k.startswith('1.') and is_lm_model:
699
+ new_k = k[2:]
700
+ new_k = new_k.replace('output_dp.', 'output_dropout.')
701
+             new_state_dict['lm_head.' + new_k] = v
+
+         else:
+             new_k = k.replace('encoder.', 'embed_tokens.')
+             new_k = new_k.replace('encoder_dp.emb.', 'embed_tokens.')
+             new_k = new_k.replace('rnns.', 'layers.')
+             new_k = new_k.replace('hidden_dps.', 'hidden_dropout.')
+             new_k = new_k.replace('input_dp.', 'input_dropout.')
+             new_state_dict['encoder.' + new_k] = v
+
+     model.load_state_dict(new_state_dict, strict=False)
+
+
+ def convert_checkpoint(input_path: str, output_dir: str) -> None:
+     """Convert original checkpoint to new format."""
+     config = LookingGlassConfig()
+     model = LookingGlassLM(config)
+     load_original_weights(model, input_path)
+     model.save_pretrained(output_dir)
+
+     tokenizer = LookingGlassTokenizer()
+     tokenizer.save_pretrained(output_dir)
+     print(f"Saved to {output_dir}")
+
+
+ # =============================================================================
+ # CLI
+ # =============================================================================
+
+ if __name__ == '__main__':
+     import argparse
+
+     parser = argparse.ArgumentParser(description='LookingGlass DNA Language Model')
+     parser.add_argument('--convert', type=str, help='Convert original weights')
+     parser.add_argument('--output', type=str, default='./lookingglass-v1', help='Output directory')
+     parser.add_argument('--test', action='store_true', help='Run tests')
+     args = parser.parse_args()
+
+     if args.convert:
+         convert_checkpoint(args.convert, args.output)
+
+     elif args.test:
+         print("Testing LookingGlass...\n")
+
+         tokenizer = LookingGlassTokenizer()
+         print(f"Vocab: {tokenizer.vocab}")
+         print(f"BOS token added: {tokenizer.add_bos_token}")
+         print(f"EOS token added: {tokenizer.add_eos_token}")
+
+         inputs = tokenizer("GATTACA", return_tensors=True)
+         print(f"\nTokenized 'GATTACA': {inputs['input_ids']}")
+         print(f"Decoded: {tokenizer.decode(inputs['input_ids'][0])}")
+
+         config = LookingGlassConfig()
+         print(f"\nConfig: bidirectional={config.bidirectional}")
+
+         # Test LookingGlass (encoder)
+         encoder = LookingGlass(config)
+         print(f"\nLookingGlass params: {sum(p.numel() for p in encoder.parameters()):,}")
+
+         encoder.eval()
+         with torch.no_grad():
+             emb = encoder.get_embeddings(inputs['input_ids'])
+             print(f"Embeddings shape: {emb.shape}")
+
+         # Test LookingGlassLM
+         lm = LookingGlassLM(config)
+         print(f"\nLookingGlassLM params: {sum(p.numel() for p in lm.parameters()):,}")
+
+         lm.eval()
+         with torch.no_grad():
+             logits = lm(inputs['input_ids'])
+             emb = lm.get_embeddings(inputs['input_ids'])
+             print(f"Logits shape: {logits.shape}")
+             print(f"Embeddings shape: {emb.shape}")
+
+         print("\nAll tests passed!")
+
+     else:
+         parser.print_help()
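The key-renaming loop in `load_original_weights` above can be checked in isolation. Below is a minimal pure-Python sketch of the same fastai-to-PyTorch encoder key mapping; the toy state dict is illustrative only, not real weights:

```python
def remap_encoder_key(k: str) -> str:
    """Translate an original fastai-style encoder key to the new layout,
    mirroring the replace chain in load_original_weights."""
    new_k = k.replace('encoder.', 'embed_tokens.')
    new_k = new_k.replace('encoder_dp.emb.', 'embed_tokens.')
    new_k = new_k.replace('rnns.', 'layers.')
    new_k = new_k.replace('hidden_dps.', 'hidden_dropout.')
    new_k = new_k.replace('input_dp.', 'input_dropout.')
    return 'encoder.' + new_k

# Toy state dict: values stand in for weight tensors.
old_state = {'rnns.0.module.weight_ih_l0': 'W', 'encoder.weight': 'E'}
new_state = {remap_encoder_key(k): v for k, v in old_state.items()}
# new_state == {'encoder.layers.0.module.weight_ih_l0': 'W',
#               'encoder.embed_tokens.weight': 'E'}
```

The replacement patterns are disjoint substrings, so the order of the `replace` calls does not affect the result here.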
lookingglass_classifier.py ADDED
@@ -0,0 +1,258 @@
+ """
+ LookingGlass Classifiers - Fine-tuned DNA sequence classifiers
+
+ Pure PyTorch implementation of LookingGlass classifiers from the paper.
+ Uses the LookingGlass encoder with a classification head.
+
+ Usage:
+     from lookingglass_classifier import LookingGlassClassifier, LookingGlassTokenizer
+
+     model = LookingGlassClassifier.from_pretrained('.')
+     tokenizer = LookingGlassTokenizer()
+
+     inputs = tokenizer(["GATTACA"], return_tensors=True)
+     logits = model(inputs['input_ids'])  # (batch, num_classes)
+     predictions = logits.argmax(dim=-1)
+ """
+
+ import json
+ import os
+ from dataclasses import dataclass, asdict, field
+ from typing import Optional, List
+
+ import torch
+ import torch.nn as nn
+
+ from lookingglass import (
+     LookingGlassConfig,
+     LookingGlassTokenizer,
+     _AWDLSTMEncoder,
+     _is_hf_hub_id,
+     _download_from_hub,
+ )
+
+ __version__ = "1.1.0"
+ __all__ = ["LookingGlassClassifierConfig", "LookingGlassClassifier", "LookingGlassTokenizer"]
+
+
+ @dataclass
+ class LookingGlassClassifierConfig(LookingGlassConfig):
+     """Configuration for LookingGlass classifier."""
+     num_classes: int = 2
+     classifier_hidden: int = 50
+     classifier_dropout: float = 0.0
+     class_names: List[str] = field(default_factory=list)
+
+     def save_pretrained(self, save_directory: str):
+         os.makedirs(save_directory, exist_ok=True)
+         with open(os.path.join(save_directory, "config.json"), 'w') as f:
+             json.dump(self.to_dict(), f, indent=2)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_path: str) -> "LookingGlassClassifierConfig":
+         if _is_hf_hub_id(pretrained_path):
+             try:
+                 config_path = _download_from_hub(pretrained_path, "config.json")
+             except Exception:
+                 return cls()
+         elif os.path.isdir(pretrained_path):
+             config_path = os.path.join(pretrained_path, "config.json")
+         else:
+             config_path = pretrained_path
+
+         if os.path.exists(config_path):
+             with open(config_path, 'r') as f:
+                 config_dict = json.load(f)
+             valid_fields = {f.name for f in cls.__dataclass_fields__.values()}
+             return cls(**{k: v for k, v in config_dict.items() if k in valid_fields})
+         return cls()
+
+
+ class LookingGlassClassifier(nn.Module):
+     """
+     LookingGlass with classification head.
+
+     Uses concat pooling (max + mean + last) followed by classification layers.
+
+     Example:
+         >>> model = LookingGlassClassifier.from_pretrained('.')
+         >>> tokenizer = LookingGlassTokenizer()
+         >>> inputs = tokenizer("GATTACA", return_tensors=True)
+         >>> logits = model(inputs['input_ids'])  # (1, num_classes)
+         >>> prediction = logits.argmax(dim=-1)
+     """
+
+     def __init__(self, config: Optional[LookingGlassClassifierConfig] = None):
+         super().__init__()
+         self.config = config or LookingGlassClassifierConfig()
+         self.encoder = _AWDLSTMEncoder(self.config)
+
+         # Concat pooling: max + mean + last = 3 * hidden_size
+         pooled_size = 3 * self.config.hidden_size
+
+         # Classification head: BatchNorm -> Linear -> ReLU -> BatchNorm -> Linear
+         self.classifier = nn.Sequential(
+             nn.BatchNorm1d(pooled_size),
+             nn.Dropout(self.config.classifier_dropout),
+             nn.Linear(pooled_size, self.config.classifier_hidden),
+             nn.ReLU(),
+             nn.BatchNorm1d(self.config.classifier_hidden),
+             nn.Dropout(self.config.classifier_dropout),
+             nn.Linear(self.config.classifier_hidden, self.config.num_classes),
+         )
+
+     def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
+         """
+         Forward pass returning classification logits.
+
+         Args:
+             input_ids: Token indices (batch, seq_len)
+
+         Returns:
+             Logits (batch, num_classes)
+         """
+         self.encoder.reset()
+         hidden = self.encoder(input_ids)  # (batch, seq_len, hidden_size)
+
+         # Concat pooling: max, mean, last
+         max_pool = hidden.max(dim=1).values
+         mean_pool = hidden.mean(dim=1)
+         last_pool = hidden[:, -1]
+         pooled = torch.cat([max_pool, mean_pool, last_pool], dim=-1)
+
+         return self.classifier(pooled)
+
+     def predict(self, input_ids: torch.LongTensor) -> torch.Tensor:
+         """Return predicted class indices."""
+         logits = self.forward(input_ids)
+         return logits.argmax(dim=-1)
+
+     def predict_proba(self, input_ids: torch.LongTensor) -> torch.Tensor:
+         """Return class probabilities."""
+         logits = self.forward(input_ids)
+         return torch.softmax(logits, dim=-1)
+
+     def get_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
+         """Get sequence embeddings (last token) from encoder."""
+         self.encoder.reset()
+         hidden = self.encoder(input_ids)
+         return hidden[:, -1]
+
+     def save_pretrained(self, save_directory: str):
+         os.makedirs(save_directory, exist_ok=True)
+         self.config.save_pretrained(save_directory)
+         torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
+
+     @classmethod
+     def from_pretrained(
+         cls, pretrained_path: str, config: Optional[LookingGlassClassifierConfig] = None
+     ) -> "LookingGlassClassifier":
+         config = config or LookingGlassClassifierConfig.from_pretrained(pretrained_path)
+         model = cls(config)
+
+         if _is_hf_hub_id(pretrained_path):
+             model_path = _download_from_hub(pretrained_path, "pytorch_model.bin")
+         else:
+             model_path = os.path.join(pretrained_path, "pytorch_model.bin")
+
+         if os.path.exists(model_path):
+             state_dict = torch.load(model_path, map_location='cpu')
+             model.load_state_dict(state_dict, strict=False)
+
+         return model
+
+
+ def convert_classifier_weights(
+     original_path: str,
+     output_dir: str,
+     num_classes: int,
+     class_names: Optional[List[str]] = None,
+ ):
+     """
+     Convert original fastai classifier weights to pure PyTorch format.
+
+     Args:
+         original_path: Path to original .pth file
+         output_dir: Output directory for converted model
+         num_classes: Number of output classes
+         class_names: Optional list of class names
+     """
+     print(f"Loading weights from {original_path}...")
+     original = torch.load(original_path, map_location='cpu')
+     if 'model' in original:
+         original = original['model']
+
+     # Create config
+     config = LookingGlassClassifierConfig(
+         num_classes=num_classes,
+         classifier_hidden=50,
+         class_names=class_names or [],
+     )
+
+     # Create model
+     model = LookingGlassClassifier(config)
+
+     # Map weights
+     new_state = {}
+
+     # Encoder weights
+     weight_map = {
+         '0.module.encoder.weight': 'encoder.embed_tokens.weight',
+         '0.module.encoder_dp.emb.weight': 'encoder.embed_dropout.embedding.weight',
+     }
+
+     for i in range(3):
+         weight_map.update({
+             f'0.module.rnns.{i}.weight_hh_l0_raw': f'encoder.layers.{i}.weight_hh_l0_raw',
+             f'0.module.rnns.{i}.module.weight_ih_l0': f'encoder.layers.{i}.module.weight_ih_l0',
+             f'0.module.rnns.{i}.module.weight_hh_l0': f'encoder.layers.{i}.module.weight_hh_l0',
+             f'0.module.rnns.{i}.module.bias_ih_l0': f'encoder.layers.{i}.module.bias_ih_l0',
+             f'0.module.rnns.{i}.module.bias_hh_l0': f'encoder.layers.{i}.module.bias_hh_l0',
+         })
+
+     # Classifier head weights
+     # Original: 1.layers.{0,2,4,6} -> our Sequential indices
+     classifier_map = {
+         '1.layers.0.weight': 'classifier.0.weight',
+         '1.layers.0.bias': 'classifier.0.bias',
+         '1.layers.0.running_mean': 'classifier.0.running_mean',
+         '1.layers.0.running_var': 'classifier.0.running_var',
+         '1.layers.0.num_batches_tracked': 'classifier.0.num_batches_tracked',
+         '1.layers.2.weight': 'classifier.2.weight',
+         '1.layers.2.bias': 'classifier.2.bias',
+         '1.layers.4.weight': 'classifier.4.weight',
+         '1.layers.4.bias': 'classifier.4.bias',
+         '1.layers.4.running_mean': 'classifier.4.running_mean',
+         '1.layers.4.running_var': 'classifier.4.running_var',
+         '1.layers.4.num_batches_tracked': 'classifier.4.num_batches_tracked',
+         '1.layers.6.weight': 'classifier.6.weight',
+         '1.layers.6.bias': 'classifier.6.bias',
+     }
+     weight_map.update(classifier_map)
+
+     for old_key, new_key in weight_map.items():
+         if old_key in original:
+             new_state[new_key] = original[old_key]
+
+     # Load and save
+     model.load_state_dict(new_state, strict=False)
+
+     os.makedirs(output_dir, exist_ok=True)
+     config.save_pretrained(output_dir)
+     torch.save(model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
+
+     print(f"Saved to {output_dir}")
+     return model
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Convert LookingGlass classifier weights")
+     parser.add_argument("--input", required=True, help="Path to original .pth file")
+     parser.add_argument("--output", required=True, help="Output directory")
+     parser.add_argument("--num-classes", type=int, required=True, help="Number of classes")
+     parser.add_argument("--class-names", nargs="+", help="Class names")
+
+     args = parser.parse_args()
+     convert_classifier_weights(args.input, args.output, args.num_classes, args.class_names)
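The concat pooling used in `LookingGlassClassifier.forward` (per-feature max, mean, and last hidden state concatenated before the head) can be sketched without torch. The toy `hidden` below stands in for one sequence's `(seq_len, hidden_size)` activations:

```python
def concat_pool(hidden):
    """Concatenate per-feature max, mean, and last-step values, mirroring
    the max/mean/last pooling in LookingGlassClassifier.forward."""
    seq_len = len(hidden)
    n_feat = len(hidden[0])
    max_pool = [max(step[j] for step in hidden) for j in range(n_feat)]
    mean_pool = [sum(step[j] for step in hidden) / seq_len for j in range(n_feat)]
    last_pool = list(hidden[-1])
    return max_pool + mean_pool + last_pool

hidden = [[1.0, 0.0], [3.0, 2.0], [2.0, 4.0]]  # seq_len=3, hidden_size=2
pooled = concat_pool(hidden)
# max=[3.0, 4.0], mean=[2.0, 2.0], last=[2.0, 4.0] -> length 3 * hidden_size
```

This is why the head's first `BatchNorm1d` takes `3 * hidden_size` features: each pooling view contributes one `hidden_size`-length slice.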
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:31658a22f183d0c1d3f823980e1c1f32e3d6a1a9992a35d52251b61a1e439d85
+ size 67862763