# Uploaded by Mitchins via huggingface_hub ("Upload folder using huggingface_hub"),
# commit e78f911 (verified).
"""
TinyByteCNN Model for Fiction vs Non-Fiction Classification
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import unicodedata
import re
from typing import Union, List
class SE(nn.Module):
    """Squeeze-and-Excitation channel gate for ``[B, C, T]`` feature maps.

    Each channel is averaged over the last (time) axis, passed through a
    two-layer bottleneck MLP (SiLU, then sigmoid), and the resulting
    per-channel weight in (0, 1) rescales the input.

    Args:
        c: number of channels.
        r: reduction ratio; the bottleneck width is ``c // r`` but never below 4.
    """

    def __init__(self, c, r=8):
        super().__init__()
        hidden = c // r
        if hidden < 4:
            hidden = 4  # keep the bottleneck from collapsing for narrow inputs
        self.fc1 = nn.Linear(c, hidden)
        self.fc2 = nn.Linear(hidden, c)

    def forward(self, x):
        pooled = torch.mean(x, dim=-1)  # squeeze: [B, C]
        gate = torch.sigmoid(self.fc2(F.silu(self.fc1(pooled))))  # excite: [B, C]
        return gate[..., None] * x  # broadcast the gate over T
class SepResBlock(nn.Module):
    """Depthwise-separable residual block with Squeeze-Excitation.

    Residual branch: depthwise conv (k, stride, dilation) -> norm -> SiLU ->
    pointwise 1x1 conv -> norm -> SE -> dropout. It is added to the input
    (1x1-projected when the stride or channel count changes) and the sum is
    passed through SiLU.

    Args:
        c_in, c_out: input / output channel counts.
        k: depthwise kernel size; padding is ``(k - 1) // 2 * dilation``.
        stride: temporal stride of the depthwise conv and of the projection.
        dilation: dilation of the depthwise conv.
        use_gn: use ``GroupNorm(32, C)`` (C must be divisible by 32) instead
            of ``BatchNorm1d``.
        se_ratio: reduction ratio passed to ``SE``.
        drop: element-wise dropout probability on the residual branch
            (a plain ``nn.Dropout``, not per-sample stochastic depth).
    """

    def __init__(self, c_in, c_out, k=7, stride=1, dilation=1, use_gn=False, se_ratio=8, drop=0.0):
        super().__init__()

        def make_norm(channels):
            if use_gn:
                return nn.GroupNorm(32, channels)
            return nn.BatchNorm1d(channels)

        pad = dilation * ((k - 1) // 2)
        # NOTE: creation order is kept fixed; it determines parameter init order.
        self.dw = nn.Conv1d(
            c_in, c_in, k,
            stride=stride, dilation=dilation, padding=pad,
            groups=c_in, bias=False,
        )
        self.bn1 = make_norm(c_in)
        self.pw = nn.Conv1d(c_in, c_out, 1, bias=False)
        self.bn2 = make_norm(c_out)
        self.se = SE(c_out, se_ratio)
        self.drop = nn.Dropout(p=drop)
        reshapes = stride != 1 or c_in != c_out
        self.proj = nn.Conv1d(c_in, c_out, 1, stride=stride, bias=False) if reshapes else None

    def forward(self, x):
        shortcut = x if self.proj is None else self.proj(x)
        branch = F.silu(self.bn1(self.dw(x)))
        branch = self.se(self.bn2(self.pw(branch)))
        return F.silu(shortcut + self.drop(branch))
class TinyByteCNN(nn.Module):
    """TinyByteCNN for Fiction vs Non-Fiction Classification.

    Input bytes are embedded, downsampled by a stride-2 stem convolution and
    four stages of ``SepResBlock`` (the first block of each stage has stride
    2, so the sequence is reduced ~32x overall), then globally average- and
    max-pooled and projected to a single logit.
    """

    # Hyper-parameters used when ``config`` is None. They also fill any keys
    # missing from dict configs (e.g. a partial config.json).
    DEFAULT_CONFIG = {
        'vocab_size': 256,                # one embedding row per byte value
        'embed_dim': 32,
        'widths': (128, 192, 256, 320),   # per-stage channels; widths[0] also used by the stem
        'use_gn': False,                  # GroupNorm(32, C) instead of BatchNorm1d
        'head_drop': 0.1,                 # dropout before the final Linear
        'stochastic_depth': 0.05,         # passed as ``drop`` (plain Dropout) to every block
    }

    def __init__(self, config=None):
        """Build the network.

        Args:
            config: None for ``DEFAULT_CONFIG``; a dict of hyper-parameters
                (missing keys fall back to ``DEFAULT_CONFIG``); or any object
                exposing ``vocab_size``, ``embed_dim``, ``widths``, ``use_gn``,
                ``head_drop`` and ``stochastic_depth`` attributes (used as-is).
        """
        super().__init__()
        config = self._build_config(config)
        self.config = config
        # Embedding layer for bytes
        self.embed = nn.Embedding(config.vocab_size, config.embed_dim)
        # Stem convolution (stride 2)
        self.stem = nn.Conv1d(config.embed_dim, config.widths[0], 5, stride=2, padding=2, bias=False)
        self.bn0 = nn.BatchNorm1d(config.widths[0]) if not config.use_gn else nn.GroupNorm(32, config.widths[0])
        # (num_blocks, out_channels, per-block dilations); the first block of
        # each stage downsamples with stride 2.
        cfg = [
            (2, config.widths[0], [1, 2]),
            (2, config.widths[1], [1, 2]),
            (3, config.widths[2], [1, 2, 4]),
            (3, config.widths[3], [1, 2, 8]),
        ]
        stages = []
        c_prev = config.widths[0]
        for blocks, c, ds in cfg:
            for i in range(blocks):
                stride = 2 if i == 0 else 1
                stages.append(SepResBlock(c_prev, c, k=7, stride=stride, dilation=ds[i],
                                          use_gn=config.use_gn, drop=config.stochastic_depth))
                c_prev = c
        self.stages = nn.Sequential(*stages)
        # Classification head over concatenated [avg, max] pooled features
        self.head = nn.Sequential(
            nn.Dropout(p=config.head_drop),
            nn.Linear(2 * config.widths[-1], 1),
        )

    def forward(self, x_bytes):
        """
        Args:
            x_bytes: [B, T] integer tensor of byte values (e.g. uint8)
        Returns:
            logits: [B] tensor of binary classification logits
        """
        x = self.embed(x_bytes.long())       # [B, T, E]
        x = x.transpose(1, 2).contiguous()   # [B, E, T]
        x = F.silu(self.bn0(self.stem(x)))   # [B, C0, T/2]
        x = self.stages(x)                   # [B, C, T/32]
        # Global pooling
        avg = x.mean(dim=-1)
        mx = x.amax(dim=-1)
        feats = torch.cat([avg, mx], dim=1)
        return self.head(feats).squeeze(1)

    @classmethod
    def _build_config(cls, config):
        """Normalize ``config`` into an object with attribute access.

        None and dicts become a ``SimpleNamespace``, with ``DEFAULT_CONFIG``
        filling any missing keys. Any other object is returned unchanged.
        """
        from types import SimpleNamespace

        if config is not None and not isinstance(config, dict):
            return config
        values = {**cls.DEFAULT_CONFIG, **(config or {})}
        values['widths'] = list(values['widths'])  # private copy; never share the default
        return SimpleNamespace(**values)

    @classmethod
    def _config_to_dict(cls, config):
        """Return the model's hyper-parameters from ``config`` as a plain dict."""
        if isinstance(config, dict):
            return {key: config[key] for key in cls.DEFAULT_CONFIG if key in config}
        return {key: getattr(config, key) for key in cls.DEFAULT_CONFIG if hasattr(config, key)}

    @classmethod
    def _from_state_dict(cls, state_dict, config):
        """Instantiate with ``config`` and strictly load ``state_dict``."""
        model = cls(config)
        model.load_state_dict(state_dict)
        return model

    @classmethod
    def _from_checkpoint(cls, checkpoint):
        """Build a model from a torch checkpoint.

        Accepts either a plain state dict or a dict holding
        ``'model_state_dict'`` and an optional ``'config'``.
        """
        config = checkpoint.get('config', None)
        state_dict = checkpoint.get('model_state_dict', checkpoint)
        return cls._from_state_dict(state_dict, config)

    @staticmethod
    def _load_torch_checkpoint(path):
        """``torch.load`` a checkpoint onto the CPU.

        SECURITY: ``weights_only=False`` unpickles arbitrary Python objects,
        which can execute code. It is kept because checkpoints may pickle a
        custom config object. Only load checkpoints from trusted sources.
        """
        return torch.load(path, weights_only=False, map_location='cpu')

    @classmethod
    def from_pretrained(cls, path_or_repo, use_safetensors=True):
        """Load a pretrained model (supports both .bin and .safetensors).

        ``path_or_repo`` (str or path-like) is resolved as follows:

        * existing directory: ``model.safetensors`` when ``use_safetensors``
          (hyper-parameters from a sibling ``config.json`` if present),
          otherwise ``pytorch_model.bin``;
        * existing file: safetensors if the name ends in ``.safetensors``
          (default config), otherwise a torch checkpoint;
        * anything else: a Hugging Face Hub repo id. ``model.safetensors`` is
          tried first (when ``use_safetensors``) with the default config. On any
          failure it falls back to ``pytorch_model.bin``.

        Raises:
            FileNotFoundError: a directory contains no loadable weight file.
        """
        import json
        import os
        from pathlib import Path

        path_str = os.fspath(path_or_repo)

        if os.path.isdir(path_str):
            base_path = Path(path_str)
            safetensors_path = base_path / "model.safetensors"
            pytorch_path = base_path / "pytorch_model.bin"
            if use_safetensors and safetensors_path.exists():
                from safetensors.torch import load_file
                config = None
                config_path = base_path / "config.json"
                if config_path.exists():
                    with open(config_path, encoding="utf-8") as f:
                        config = json.load(f)
                return cls._from_state_dict(load_file(str(safetensors_path)), config)
            if pytorch_path.exists():
                return cls._from_checkpoint(cls._load_torch_checkpoint(pytorch_path))
            # Fail loudly rather than hand the caller an unusable result.
            wanted = "model.safetensors or pytorch_model.bin" if use_safetensors else "pytorch_model.bin"
            raise FileNotFoundError(f"No {wanted} found in directory {path_str!r}")

        if os.path.isfile(path_str):
            if path_str.endswith('.safetensors'):
                from safetensors.torch import load_file
                return cls._from_state_dict(load_file(path_str), None)
            return cls._from_checkpoint(cls._load_torch_checkpoint(path_str))

        # HuggingFace hub loading
        from huggingface_hub import hf_hub_download
        if use_safetensors:
            try:
                model_file = hf_hub_download(repo_id=path_str, filename="model.safetensors")
                from safetensors.torch import load_file
                return cls._from_state_dict(load_file(model_file), None)
            except Exception:
                # Deliberate best-effort: if safetensors is unavailable, the file
                # is missing, or its weights don't fit the default config, fall
                # back to the pytorch checkpoint (which may carry its own config).
                pass
        model_file = hf_hub_download(repo_id=path_str, filename="pytorch_model.bin")
        return cls._from_checkpoint(cls._load_torch_checkpoint(model_file))

    def save_pretrained(self, save_path):
        """Save weights and hyper-parameters to ``<save_path>/pytorch_model.bin``.

        The config is stored as a plain dict of hyper-parameters. Config objects
        pickle cannot handle (e.g. instances of classes created at runtime
        with ``type(...)``) then still save, and loading needs no custom classes.
        """
        import os
        os.makedirs(save_path, exist_ok=True)
        torch.save({
            'model_state_dict': self.state_dict(),
            'config': self._config_to_dict(self.config),
        }, os.path.join(save_path, 'pytorch_model.bin'))
def preprocess_text(text: str, max_len: int = 4096) -> torch.Tensor:
    """
    Preprocess text to bytes for model input.

    Steps: Unicode NFC normalization, CRLF -> LF, every run of 3 or more
    whitespace characters (newlines included) replaced by a single space,
    UTF-8 encoding, then truncation / zero-padding to ``max_len`` bytes.
    Truncation is byte-based and may split a multi-byte character.

    Args:
        text: Input text string
        max_len: Maximum sequence length in bytes (default 4096)
    Returns:
        uint8 tensor of shape [1, max_len]; unused positions are 0.
    """
    text = unicodedata.normalize('NFC', text)
    text = text.replace('\r\n', '\n')
    # NOTE: runs of 3+ whitespace collapse to ONE space (runs of 1-2 are kept).
    # Presumably this must match training-time preprocessing, so do not change
    # it without retraining — TODO confirm against the training pipeline.
    text = re.sub(r'\s{3,}', ' ', text)
    # errors='ignore' drops lone surrogates, the only str content UTF-8 rejects.
    encoded = text.encode('utf-8', errors='ignore')[:max_len]
    input_ids = np.zeros(max_len, dtype=np.uint8)
    if encoded:
        input_ids[:len(encoded)] = np.frombuffer(encoded, dtype=np.uint8)
    return torch.from_numpy(input_ids).unsqueeze(0)  # Add batch dimension
def classify_text(text: Union[str, List[str]], model=None, device='cpu'):
    """
    Classify one text or a list of texts as Fiction or Non-Fiction.

    Args:
        text: A single string, or a list of strings
        model: Already-loaded model. If None, one is loaded via
            ``TinyByteCNN.from_pretrained("fiction_classifier_hf")``.
        device: Torch device to run on ('cpu', 'cuda', 'mps'). The model is
            moved there and switched to eval mode.
    Returns:
        For a string input, one result dict; for a list, a list of dicts.
        Each dict holds 'text' (first 100 chars, with '...' if cut),
        'prediction' ("Non-Fiction" when P(non-fiction) > 0.5, else
        "Fiction"), 'confidence' (probability of the predicted class) and
        'probability_nonfiction' (sigmoid of the logit).
    """
    single = isinstance(text, str)
    if model is None:
        model = TinyByteCNN.from_pretrained("fiction_classifier_hf")
    model = model.to(device)
    model.eval()

    outcomes = []
    for sample in ([text] if single else text):
        batch = preprocess_text(sample).to(device)
        with torch.no_grad():
            p_nonfiction = torch.sigmoid(model(batch)).item()
        is_nonfiction = p_nonfiction > 0.5
        outcomes.append({
            'text': sample if len(sample) <= 100 else sample[:100] + '...',
            'prediction': "Non-Fiction" if is_nonfiction else "Fiction",
            'confidence': p_nonfiction if is_nonfiction else (1 - p_nonfiction),
            'probability_nonfiction': p_nonfiction,
        })
    return outcomes[0] if single else outcomes
if __name__ == "__main__":
    # Demo: classify one sentence using a locally saved training checkpoint.
    checkpoint_path = "fiction_model_output_cnn/best_model.pt"
    example = "The detective's coffee had gone cold hours ago, but she hardly noticed."
    classifier = TinyByteCNN.from_pretrained(checkpoint_path)
    outcome = classify_text(example, classifier)
    report = (
        f"Text: {outcome['text']}",
        f"Prediction: {outcome['prediction']}",
        f"Confidence: {outcome['confidence']:.1%}",
    )
    print("\n".join(report))