|
|
""" |
|
|
LookingGlass Classifiers - Fine-tuned DNA sequence classifiers |
|
|
|
|
|
Pure PyTorch implementation of LookingGlass classifiers from the paper. |
|
|
Uses LookingGlass encoder with classification head. |
|
|
|
|
|
Usage: |
|
|
from lookingglass_classifier import LookingGlassClassifier, LookingGlassTokenizer |
|
|
|
|
|
model = LookingGlassClassifier.from_pretrained('.') |
|
|
tokenizer = LookingGlassTokenizer() |
|
|
|
|
|
inputs = tokenizer(["GATTACA"], return_tensors=True) |
|
|
logits = model(inputs['input_ids']) # (batch, num_classes) |
|
|
predictions = logits.argmax(dim=-1) |
|
|
""" |
|
|
|
|
|
import json |
|
|
import os |
|
|
from dataclasses import dataclass, asdict, field |
|
|
from typing import Optional, List |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from lookingglass import ( |
|
|
LookingGlassConfig, |
|
|
LookingGlassTokenizer, |
|
|
_AWDLSTMEncoder, |
|
|
_is_hf_hub_id, |
|
|
_download_from_hub, |
|
|
) |
|
|
|
|
|
__version__ = "1.1.0" |
|
|
__all__ = ["LookingGlassClassifierConfig", "LookingGlassClassifier", "LookingGlassTokenizer"] |
|
|
|
|
|
|
|
|
@dataclass
class LookingGlassClassifierConfig(LookingGlassConfig):
    """Configuration for LookingGlass classifier.

    Extends the base LookingGlass config with classification-head settings.
    """

    num_classes: int = 2  # number of output classes
    classifier_hidden: int = 50  # hidden width of the classification head
    classifier_dropout: float = 0.0  # dropout rate used inside the head
    class_names: List[str] = field(default_factory=list)  # optional human-readable labels

    def save_pretrained(self, save_directory: str):
        """Write this config as ``config.json`` under *save_directory*."""
        os.makedirs(save_directory, exist_ok=True)
        config_file = os.path.join(save_directory, "config.json")
        with open(config_file, 'w') as fh:
            json.dump(self.to_dict(), fh, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_path: str) -> "LookingGlassClassifierConfig":
        """Load a config from a HF Hub id, a local directory, or a JSON file path.

        Falls back to default settings whenever no config.json can be located.
        """
        if _is_hf_hub_id(pretrained_path):
            try:
                config_path = _download_from_hub(pretrained_path, "config.json")
            except Exception:
                # Best-effort: hub repos without a config.json get defaults.
                return cls()
        elif os.path.isdir(pretrained_path):
            config_path = os.path.join(pretrained_path, "config.json")
        else:
            config_path = pretrained_path

        if not os.path.exists(config_path):
            return cls()

        with open(config_path, 'r') as fh:
            raw = json.load(fh)
        # __dataclass_fields__ keys are the field names (incl. inherited ones);
        # unknown keys in the JSON are silently dropped.
        known = set(cls.__dataclass_fields__)
        return cls(**{key: val for key, val in raw.items() if key in known})
|
|
|
|
|
|
|
|
class LookingGlassClassifier(nn.Module):
    """
    LookingGlass encoder topped with a classification head.

    Pools the encoder's hidden states with concat pooling (max + mean + last
    token) and feeds the result through a BatchNorm/Dropout/Linear stack.

    Example:
        >>> model = LookingGlassClassifier.from_pretrained('.')
        >>> tokenizer = LookingGlassTokenizer()
        >>> inputs = tokenizer("GATTACA", return_tensors=True)
        >>> logits = model(inputs['input_ids'])  # (1, num_classes)
        >>> prediction = logits.argmax(dim=-1)
    """

    def __init__(self, config: Optional[LookingGlassClassifierConfig] = None):
        super().__init__()
        self.config = config if config is not None else LookingGlassClassifierConfig()
        self.encoder = _AWDLSTMEncoder(self.config)

        # Concat pooling stacks max, mean, and last hidden state side by side.
        concat_dim = 3 * self.config.hidden_size
        head_dim = self.config.classifier_hidden
        drop = self.config.classifier_dropout

        # NOTE: layer order is load-bearing — converted checkpoints address
        # these modules by index as classifier.{0,2,4,6}.
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(concat_dim),
            nn.Dropout(drop),
            nn.Linear(concat_dim, head_dim),
            nn.ReLU(),
            nn.BatchNorm1d(head_dim),
            nn.Dropout(drop),
            nn.Linear(head_dim, self.config.num_classes),
        )

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Forward pass returning classification logits.

        Args:
            input_ids: Token indices (batch, seq_len)

        Returns:
            Logits (batch, num_classes)
        """
        self.encoder.reset()
        states = self.encoder(input_ids)

        # Concat pooling: [max over time, mean over time, final time step].
        pooled = torch.cat(
            [states.max(dim=1).values, states.mean(dim=1), states[:, -1]],
            dim=-1,
        )
        return self.classifier(pooled)

    def predict(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Return predicted class indices."""
        return self.forward(input_ids).argmax(dim=-1)

    def predict_proba(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Return class probabilities."""
        return torch.softmax(self.forward(input_ids), dim=-1)

    def get_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Get sequence embeddings (last token) from encoder."""
        self.encoder.reset()
        states = self.encoder(input_ids)
        return states[:, -1]

    def save_pretrained(self, save_directory: str):
        """Save config and model weights into *save_directory*."""
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)
        weights_file = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), weights_file)

    @classmethod
    def from_pretrained(
        cls, pretrained_path: str, config: Optional[LookingGlassClassifierConfig] = None
    ) -> "LookingGlassClassifier":
        """Build a model from a HF Hub id or local directory.

        Missing weight files are tolerated: the model is returned with its
        freshly initialized parameters.
        """
        if config is None:
            config = LookingGlassClassifierConfig.from_pretrained(pretrained_path)
        model = cls(config)

        if _is_hf_hub_id(pretrained_path):
            weights_file = _download_from_hub(pretrained_path, "pytorch_model.bin")
        else:
            weights_file = os.path.join(pretrained_path, "pytorch_model.bin")

        if os.path.exists(weights_file):
            # strict=False tolerates missing/extra keys from older checkpoints.
            state_dict = torch.load(weights_file, map_location='cpu')
            model.load_state_dict(state_dict, strict=False)

        return model
|
|
|
|
|
|
|
|
def convert_classifier_weights(
    original_path: str,
    output_dir: str,
    num_classes: int,
    class_names: Optional[List[str]] = None,
):
    """
    Convert original fastai classifier weights to pure PyTorch format.

    Args:
        original_path: Path to original .pth file
        output_dir: Output directory for converted model
        num_classes: Number of output classes
        class_names: Optional list of class names
    """
    print(f"Loading weights from {original_path}...")
    checkpoint = torch.load(original_path, map_location='cpu')
    # fastai checkpoints wrap the state dict under a 'model' key.
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']

    config = LookingGlassClassifierConfig(
        num_classes=num_classes,
        classifier_hidden=50,
        class_names=class_names or [],
    )
    model = LookingGlassClassifier(config)

    # Map fastai parameter names onto the pure-PyTorch module layout.
    rename = {
        '0.module.encoder.weight': 'encoder.embed_tokens.weight',
        '0.module.encoder_dp.emb.weight': 'encoder.embed_dropout.embedding.weight',
    }

    # Three stacked RNN layers, each with a raw weight plus wrapped params.
    for layer in range(3):
        src = f'0.module.rnns.{layer}'
        dst = f'encoder.layers.{layer}'
        rename[f'{src}.weight_hh_l0_raw'] = f'{dst}.weight_hh_l0_raw'
        for param in ('weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0'):
            rename[f'{src}.module.{param}'] = f'{dst}.module.{param}'

    # Classification head: fastai '1.layers.N' -> nn.Sequential 'classifier.N'.
    # Indices 0 and 4 are BatchNorm layers and carry running statistics.
    for idx in (0, 2, 4, 6):
        rename[f'1.layers.{idx}.weight'] = f'classifier.{idx}.weight'
        rename[f'1.layers.{idx}.bias'] = f'classifier.{idx}.bias'
        if idx in (0, 4):
            for extra in ('running_mean', 'running_var', 'num_batches_tracked'):
                rename[f'1.layers.{idx}.{extra}'] = f'classifier.{idx}.{extra}'

    converted = {new: checkpoint[old] for old, new in rename.items() if old in checkpoint}
    model.load_state_dict(converted, strict=False)

    os.makedirs(output_dir, exist_ok=True)
    config.save_pretrained(output_dir)
    torch.save(model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))

    print(f"Saved to {output_dir}")
    return model
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # CLI wrapper around convert_classifier_weights.
    cli = argparse.ArgumentParser(description="Convert LookingGlass classifier weights")
    cli.add_argument("--input", required=True, help="Path to original .pth file")
    cli.add_argument("--output", required=True, help="Output directory")
    cli.add_argument("--num-classes", type=int, required=True, help="Number of classes")
    cli.add_argument("--class-names", nargs="+", help="Class names")

    opts = cli.parse_args()
    convert_classifier_weights(opts.input, opts.output, opts.num_classes, opts.class_names)
|
|
|