# TreeLSTM - DEEPSEEK - Classification (3 classes)

Toxicity prediction model trained on the DEEPSEEK dataset.
| Property | Value |
|---|---|
| Model | TreeLSTM |
| Task | Classification (3 classes) |
| Dataset | deepseek |
| Framework | PyTorch / PyTorch Lightning |
## Class: `TreeLSTMModel`

```python
from src.models.lstm import TreeLSTMModel

model = TreeLSTMModel(
    vocab: dict,             # Token-to-index mapping
    hidden_dim: int = 256,   # Hidden dimension
    num_classes: int = 1,    # 1 = regression, 2+ = classification
    dropout: float = 0.3,
    lr: float = 0.001,
    loss_fn: str = 'auto',
)
```
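The `vocab` argument is a plain token-to-index dictionary. As a minimal sketch of how such a mapping can be built from tokenized text (the `<pad>`/`<unk>` conventions and the `build_vocab` helper are illustrative assumptions, not part of ToxicThesis):

```python
from collections import Counter

def build_vocab(token_lists, min_freq=1):
    """Build a token-to-index mapping with reserved special tokens."""
    counts = Counter(tok for toks in token_lists for tok in toks)
    vocab = {"<pad>": 0, "<unk>": 1}  # special indices are an assumption
    for tok, freq in counts.most_common():
        if freq >= min_freq:
            vocab[tok] = len(vocab)
    return vocab

vocab = build_vocab([["this", "is", "fine"], ["this", "again"]])
# "this" appears twice, so it gets the first non-special index
```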
### Methods

| Method | Description |
|---|---|
| `forward(batch)` | Processes a batch of tree graphs; returns logits. |
| `compute_loss(logits, targets)` | Computes the loss based on `num_classes`. |
| `load_from_checkpoint(path)` | Loads the model from a checkpoint. |
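Since `compute_loss` picks its criterion from `num_classes`, the dispatch presumably looks something like the sketch below (the actual implementation lives in ToxicThesis; the function name and loss labels here are assumptions for illustration):

```python
def pick_loss(num_classes, loss_fn="auto"):
    """Choose a loss criterion name from the model configuration.

    With loss_fn='auto', regression (num_classes == 1) maps to a
    BCE-with-logits-style loss and classification to cross-entropy.
    """
    if loss_fn != "auto":
        return loss_fn  # an explicit choice overrides the heuristic
    return "bce_with_logits" if num_classes == 1 else "cross_entropy"
```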
## Tree Input Format

The model expects constituency trees converted to a PyTorch Geometric batch with:

- `x`: Node features (token embeddings)
- `edge_index`: Parent-child edges
- `batch`: Batch assignment for each node
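To illustrate the expected graph layout, here is a dependency-free sketch that flattens a nested tree into node labels plus parent-to-child edge pairs (the nested-tuple input format and the `tree_to_graph` helper are assumptions for illustration; ToxicThesis itself uses Stanza parses and PyTorch Geometric batches):

```python
def tree_to_graph(tree):
    """Flatten a nested (label, *children) tree into nodes and parent->child edges."""
    nodes, edges = [], []

    def visit(node, parent):
        idx = len(nodes)  # depth-first node index
        label = node[0] if isinstance(node, tuple) else node
        nodes.append(label)
        if parent is not None:
            edges.append((parent, idx))
        if isinstance(node, tuple):
            for child in node[1:]:
                visit(child, idx)

    visit(tree, None)
    return nodes, edges

nodes, edges = tree_to_graph(("S", ("NP", "you"), ("VP", "rock")))
# nodes -> ["S", "NP", "you", "VP", "rock"]
# edges -> [(0, 1), (1, 2), (0, 3), (3, 4)]
```

The `edges` list corresponds to the columns of `edge_index`, and `nodes` to the token indices that get embedded into `x`.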
## Usage with ToxicThesis (Recommended)

```python
# 1. Clone the ToxicThesis repository:
#    git clone https://github.com/simo-corbo/ToxicThesis
#    cd ToxicThesis && pip install -r requirements.txt
import pickle

import torch
from huggingface_hub import snapshot_download

# 2. Download the model files
model_dir = snapshot_download(
    repo_id="simocorbo/toxicthesis-deepseek-tree-classification-3",
    allow_patterns=["checkpoints/*", "*.pkl"],
)

# 3. Load the vocabulary
with open(f"{model_dir}/vocab_stanza_hybrid.pkl", "rb") as f:
    vocab = pickle.load(f)

# 4. Import and load the model from ToxicThesis
from src.models.lstm import TreeLSTMModel

checkpoint = torch.load(f"{model_dir}/checkpoints/best.pt", map_location="cpu")
hparams = checkpoint.get("hyper_parameters", {})
model = TreeLSTMModel(
    vocab=vocab,
    hidden_dim=hparams.get("hidden_dim", 256),
    num_classes=hparams.get("num_classes", 3),
)
model.load_state_dict(checkpoint.get("state_dict", checkpoint), strict=False)
model.eval()

# 5. For inference, use the preprocessing pipeline to convert text to tree format
from src.preprocessing.tree_utils import text_to_tree_batch

tree_batch = text_to_tree_batch("Your text here", vocab)
with torch.no_grad():
    logits = model(tree_batch)

if model.num_classes == 1:
    score = torch.sigmoid(logits).item()
    print(f"Score: {score}")
else:
    probs = torch.softmax(logits, dim=-1)
    print(f"Probabilities: {probs}")
```
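For reference, the classification branch above is just a softmax over the three class logits followed by an argmax; a dependency-free stand-in for the `torch` calls (the `interpret` helper is illustrative, not part of ToxicThesis):

```python
import math

def interpret(logits):
    """Softmax over class logits, then argmax -> predicted class index."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    return probs, probs.index(max(probs))

probs, cls = interpret([0.2, 1.5, -0.3])
# probs sums to 1.0; cls is the index of the largest logit
```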
## Note on Standalone Usage
TreeLSTM requires constituency parsing and tree-to-graph conversion. For standalone usage you would need to reimplement the tree preprocessing pipeline yourself, so we recommend using ToxicThesis directly.
## Score Interpretation

| Output | Range | Meaning |
|---|---|---|
| `probabilities` | `List[float]` | Probability distribution over the 3 classes. |
| `class` | 0 to 2 | Predicted class (argmax of the probabilities). |
**Classes:** 3 toxicity levels, where a higher class index means more toxic.
## Files

| File | Description |
|---|---|
| `checkpoints/best.pt` | Model checkpoint (best validation loss) |
| `hparams.yaml` | Hyperparameters used for training |
| `train.csv` | Training metrics per epoch |
| `val.csv` | Validation metrics per epoch |
| `vocab_stanza_hybrid.pkl` | Vocabulary (for tree-based models) |
## Installation

```bash
# Clone ToxicThesis for the full model implementations
git clone https://github.com/simo-corbo/ToxicThesis
cd ToxicThesis
pip install -r requirements.txt

# Or install the dependencies directly
pip install torch transformers huggingface_hub fasttext-wheel stanza
```
## Citation

```bibtex
@software{toxicthesis2025,
  title={ToxicThesis},
  author={Corbo, Simone},
  year={2025},
  url={https://github.com/simo-corbo/ToxicThesis}
}
```