|
|
""" |
|
|
TinyByteCNN Model for Fiction vs Non-Fiction Classification |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import unicodedata |
|
|
import re |
|
|
from typing import Union, List |
|
|
|
|
|
|
|
|
class SE(nn.Module):
    """Squeeze-and-Excitation channel attention for 1D feature maps.

    Pools each channel over time, pushes the pooled vector through a
    two-layer bottleneck MLP, and rescales the input channels by the
    resulting sigmoid gates.
    """

    def __init__(self, c, r=8):
        """
        Args:
            c: number of input/output channels
            r: reduction ratio for the bottleneck (hidden width = c // r,
               floored at 4)
        """
        super().__init__()
        hidden = max(c // r, 4)  # keep a minimum bottleneck width
        self.fc1 = nn.Linear(c, hidden)
        self.fc2 = nn.Linear(hidden, c)

    def forward(self, x):
        """Gate x of shape [B, C, T] channel-wise; returns same shape."""
        squeezed = x.mean(dim=-1)  # global average pool over time -> [B, C]
        gate = torch.sigmoid(self.fc2(F.silu(self.fc1(squeezed))))
        return x * gate.unsqueeze(-1)
|
|
|
|
|
|
|
|
class SepResBlock(nn.Module):
    """Depthwise-separable residual block with SE attention.

    Main path: depthwise conv -> norm -> SiLU -> pointwise conv -> norm ->
    SE -> dropout; added to a (possibly 1x1-projected) shortcut, then SiLU.
    """

    def __init__(self, c_in, c_out, k=7, stride=1, dilation=1, use_gn=False, se_ratio=8, drop=0.0):
        """
        Args:
            c_in, c_out: input / output channel counts
            k: depthwise kernel size
            stride: temporal stride (applied in the depthwise conv)
            dilation: depthwise dilation
            use_gn: use GroupNorm(32 groups) instead of BatchNorm1d
            se_ratio: reduction ratio forwarded to the SE module
            drop: dropout probability on the residual branch
        """
        super().__init__()
        # Pick the normalization factory once; used for both norm layers.
        if use_gn:
            Norm = lambda c: nn.GroupNorm(32, c)
        else:
            Norm = nn.BatchNorm1d

        pad = ((k - 1) // 2) * dilation  # "same"-style padding for the dilated kernel
        self.dw = nn.Conv1d(c_in, c_in, k, stride=stride, dilation=dilation,
                            padding=pad, groups=c_in, bias=False)
        self.bn1 = Norm(c_in)
        self.pw = nn.Conv1d(c_in, c_out, 1, bias=False)
        self.bn2 = Norm(c_out)
        self.se = SE(c_out, se_ratio)
        self.drop = nn.Dropout(p=drop)

        # Shortcut needs a 1x1 projection whenever shape changes.
        if stride != 1 or c_in != c_out:
            self.proj = nn.Conv1d(c_in, c_out, 1, stride=stride, bias=False)
        else:
            self.proj = None

    def forward(self, x):
        """x: [B, c_in, T] -> [B, c_out, ceil(T / stride)]."""
        branch = F.silu(self.bn1(self.dw(x)))
        branch = self.se(self.bn2(self.pw(branch)))
        branch = self.drop(branch)
        shortcut = x if self.proj is None else self.proj(x)
        return F.silu(shortcut + branch)
|
|
|
|
|
|
|
|
class TinyByteCNN(nn.Module):
    """TinyByteCNN for Fiction vs Non-Fiction Classification.

    Byte-level 1D CNN: a byte embedding feeds a strided stem convolution,
    four stages of separable residual blocks, global avg+max pooling, and
    a single-logit linear head.
    """

    def __init__(self, config=None):
        """Build the network.

        Args:
            config: object exposing vocab_size, embed_dim, widths (list of
                4 ints), use_gn, head_drop and stochastic_depth attributes.
                When None, a default configuration is constructed.
        """
        super().__init__()

        if config is None:
            # Anonymous namespace object carrying the default hyperparameters.
            config = type('Config', (), {
                'vocab_size': 256,
                'embed_dim': 32,
                'widths': [128, 192, 256, 320],
                'use_gn': False,
                'head_drop': 0.1,
                'stochastic_depth': 0.05
            })()

        self.config = config

        # One embedding vector per possible byte value.
        self.embed = nn.Embedding(config.vocab_size, config.embed_dim)

        # Stem halves the sequence length (stride 2).
        self.stem = nn.Conv1d(config.embed_dim, config.widths[0], 5, stride=2, padding=2, bias=False)
        self.bn0 = nn.GroupNorm(32, config.widths[0]) if config.use_gn else nn.BatchNorm1d(config.widths[0])

        # (num_blocks, channels, per-block dilations) for each stage; the
        # first block of every stage downsamples by 2.
        cfg = [
            (2, config.widths[0], [1, 2]),
            (2, config.widths[1], [1, 2]),
            (3, config.widths[2], [1, 2, 4]),
            (3, config.widths[3], [1, 2, 8])
        ]

        stages = []
        c_prev = config.widths[0]
        for blocks, c, ds in cfg:
            for i in range(blocks):
                stride = 2 if i == 0 else 1
                stages.append(SepResBlock(c_prev, c, k=7, stride=stride, dilation=ds[i],
                                          use_gn=config.use_gn, drop=config.stochastic_depth))
                c_prev = c

        self.stages = nn.Sequential(*stages)

        # Head consumes concatenated avg+max pooled features (2 * last width).
        self.head = nn.Sequential(
            nn.Dropout(p=config.head_drop),
            nn.Linear(2 * config.widths[-1], 1)
        )

    def forward(self, x_bytes):
        """
        Args:
            x_bytes: [B, T] uint8 tensor of byte values
        Returns:
            logits: [B] tensor of binary classification logits
        """
        x = self.embed(x_bytes.long())      # [B, T, E]
        x = x.transpose(1, 2).contiguous()  # [B, E, T] for Conv1d
        x = F.silu(self.bn0(self.stem(x)))
        x = self.stages(x)

        # Global average + max pooling over the time axis.
        avg = x.mean(dim=-1)
        mx = x.amax(dim=-1)
        feats = torch.cat([avg, mx], dim=1)

        logits = self.head(feats).squeeze(1)
        return logits

    @classmethod
    def _from_state_dict(cls, state_dict, config=None):
        """Instantiate a model and load a plain state dict into it."""
        model = cls(config)
        model.load_state_dict(state_dict)
        return model

    @classmethod
    def _from_checkpoint(cls, checkpoint):
        """Instantiate from a torch.save() checkpoint dict (or a raw state dict)."""
        config = checkpoint.get('config', None)
        state_dict = checkpoint.get('model_state_dict', checkpoint)
        return cls._from_state_dict(state_dict, config)

    @classmethod
    def from_pretrained(cls, path_or_repo, use_safetensors=True):
        """Load pretrained model (supports both .bin and .safetensors).

        Args:
            path_or_repo: local directory, local weights file path, or a
                Hugging Face Hub repo id.
            use_safetensors: prefer model.safetensors over pytorch_model.bin
                when both are available.
        Returns:
            A TinyByteCNN with the pretrained weights loaded.
        Raises:
            FileNotFoundError: when a local directory contains no weights
                file (the previous implementation silently returned None).
        """
        import os
        import json
        from pathlib import Path

        if os.path.isdir(path_or_repo):
            base_path = Path(path_or_repo)
            safetensors_path = base_path / "model.safetensors"
            pytorch_path = base_path / "pytorch_model.bin"

            if use_safetensors and safetensors_path.exists():
                from safetensors.torch import load_file
                state_dict = load_file(str(safetensors_path))

                # safetensors stores weights only; the config travels in a
                # sidecar JSON file when present.
                config_path = base_path / "config.json"
                config = None
                if config_path.exists():
                    with open(config_path) as f:
                        config = type('Config', (), json.load(f))()
                return cls._from_state_dict(state_dict, config)

            if pytorch_path.exists():
                checkpoint = torch.load(pytorch_path, weights_only=False, map_location='cpu')
                return cls._from_checkpoint(checkpoint)

            # Previously this fell through and returned None; fail loudly.
            raise FileNotFoundError(
                f"No model.safetensors or pytorch_model.bin found in {path_or_repo}")

        if os.path.isfile(path_or_repo):
            if path_or_repo.endswith('.safetensors'):
                from safetensors.torch import load_file
                return cls._from_state_dict(load_file(path_or_repo))
            checkpoint = torch.load(path_or_repo, weights_only=False, map_location='cpu')
            return cls._from_checkpoint(checkpoint)

        # Otherwise treat the argument as a Hugging Face Hub repo id.
        from huggingface_hub import hf_hub_download

        if use_safetensors:
            try:
                model_file = hf_hub_download(repo_id=path_or_repo, filename="model.safetensors")
                from safetensors.torch import load_file
                return cls._from_state_dict(load_file(model_file))
            except Exception:
                # Repo may have no safetensors file; fall back to the .bin
                # checkpoint below. (Was a bare `except:` before, which also
                # swallowed KeyboardInterrupt/SystemExit.)
                pass

        model_file = hf_hub_download(repo_id=path_or_repo, filename="pytorch_model.bin")
        checkpoint = torch.load(model_file, weights_only=False, map_location='cpu')
        return cls._from_checkpoint(checkpoint)

    def save_pretrained(self, save_path):
        """Save weights and config as pytorch_model.bin inside save_path."""
        import os
        os.makedirs(save_path, exist_ok=True)

        torch.save({
            'model_state_dict': self.state_dict(),
            'config': self.config
        }, os.path.join(save_path, 'pytorch_model.bin'))
|
|
|
|
|
|
|
|
def preprocess_text(text: str, max_len: int = 4096) -> torch.Tensor:
    """
    Preprocess text to bytes for model input.

    Normalizes to NFC, converts CRLF to LF, collapses runs of 3+ whitespace
    characters to a single space, then UTF-8 encodes (dropping unencodable
    characters), truncates to max_len, and zero-pads on the right.

    Args:
        text: Input text string
        max_len: Maximum sequence length (default 4096)

    Returns:
        Tensor of shape [1, max_len] containing byte values (uint8)
    """
    # Canonical Unicode composition, then newline/whitespace cleanup.
    text = unicodedata.normalize('NFC', text)
    text = text.replace('\r\n', '\n')
    text = re.sub(r'\s{3,}', ' ', text)

    # Encode, truncate, and right-pad into a fixed-size byte buffer.
    raw = text.encode('utf-8', errors='ignore')[:max_len]
    buf = np.zeros(max_len, dtype=np.uint8)
    buf[:len(raw)] = np.frombuffer(raw, dtype=np.uint8)

    return torch.from_numpy(buf).unsqueeze(0)
|
|
|
|
|
|
|
|
def classify_text(text: Union[str, List[str]], model=None, device='cpu'):
    """
    Classify text as fiction or non-fiction.

    Args:
        text: Single string or list of strings to classify
        model: Pre-loaded model (optional; loaded from "fiction_classifier_hf"
            when omitted)
        device: Device to run on ('cpu', 'cuda', 'mps')

    Returns:
        Dictionary with predictions and confidence scores for a single
        string; a list of such dictionaries for a list of strings
    """
    if model is None:
        model = TinyByteCNN.from_pretrained("fiction_classifier_hf")

    model = model.to(device)
    model.eval()

    # Normalize to a list; remember whether to unwrap the result at the end.
    single = isinstance(text, str)
    texts = [text] if single else text

    results = []
    for sample in texts:
        batch = preprocess_text(sample).to(device)

        with torch.no_grad():
            prob = torch.sigmoid(model(batch)).item()

        # The sigmoid output is the probability of the "Non-Fiction" class.
        is_nonfiction = prob > 0.5
        results.append({
            'text': sample[:100] + '...' if len(sample) > 100 else sample,
            'prediction': "Non-Fiction" if is_nonfiction else "Fiction",
            'confidence': prob if is_nonfiction else (1 - prob),
            'probability_nonfiction': prob
        })

    return results[0] if single else results
|
|
|
|
|
|
|
|
def _demo() -> None:
    """Smoke test: classify one example sentence and print the result."""
    sample_text = "The detective's coffee had gone cold hours ago, but she hardly noticed."

    model = TinyByteCNN.from_pretrained("fiction_model_output_cnn/best_model.pt")
    result = classify_text(sample_text, model)

    for label, value in (
        ("Text", result['text']),
        ("Prediction", result['prediction']),
        ("Confidence", f"{result['confidence']:.1%}"),
    ):
        print(f"{label}: {value}")


if __name__ == "__main__":
    _demo()