---
language: en
license: mit
library_name: pytorch
tags:
- shakespeare
- text-classification
- bert
- pytorch
- nlp
datasets:
- lanretto/shakespeare-vs-modern-dialogue
---

# 🎭 Shakespeare Authenticator - PyTorch Implementation

## Model Description

This is a **manual PyTorch implementation** of the Shakespeare Authenticator, a binary classifier that distinguishes authentic Shakespearean text from modern writing. It fine-tunes BERT-base with a custom classification head. The training pipeline was written from scratch in raw PyTorch, without the Hugging Face `Trainer`, for educational purposes.

## Model Performance

| Metric | Value |
|--------|-------|
| **Accuracy** | 0.9835 |
| **F1-Score** | 0.9685 |
| **Test Samples** | 40,626 |
| **Avg Confidence** | 0.9938 |

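The card does not say exactly how these metrics were computed. The following is a minimal sketch of one common way to derive them from raw logits. It assumes that label `1` is the positive class for F1 and that "Avg Confidence" is the mean of each sample's highest softmax probability; `compute_metrics` is a hypothetical helper name.

```python
import torch


def compute_metrics(logits: torch.Tensor, labels: torch.Tensor, positive: int = 1) -> dict:
    """Accuracy, binary F1 and mean confidence from raw classifier logits.

    `positive` is the class scored by F1 (an assumption, not stated in the card).
    """
    probs = torch.softmax(logits, dim=-1)      # (N, num_classes)
    confidence, preds = probs.max(dim=-1)      # top probability and predicted class per sample
    accuracy = (preds == labels).float().mean().item()

    tp = ((preds == positive) & (labels == positive)).sum().item()
    fp = ((preds == positive) & (labels != positive)).sum().item()
    fn = ((preds != positive) & (labels == positive)).sum().item()
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0

    return {"accuracy": accuracy, "f1": f1, "avg_confidence": confidence.mean().item()}


# Toy values for illustration only, not real model outputs
logits = torch.tensor([[2.0, -1.0], [-0.5, 1.5], [0.3, 0.1], [-2.0, 2.0]])
labels = torch.tensor([0, 1, 1, 1])
print(compute_metrics(logits, labels))
```

F1 on a single positive class is sensitive to class imbalance, which is one plausible reason it sits below accuracy in the table above.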
### Comparison with Original Implementation

| Model | Accuracy | F1-Score |
|-------|----------|----------|
| Original (HF Trainer) | 0.9820 | 0.9658 |
| **PyTorch Manual** | **0.9835** | **0.9685** |

## Training Details

- **Architecture**: BERT-base + Custom Classification Head
- **Training Approach**: Manual PyTorch training loop
- **Learning Rates**: BERT (2e-5), Classifier (1e-4)
- **Epochs**: 3
- **Batch Size**: 128
- **Best Epoch**: 3
- **Best Validation Accuracy**: 0.9849

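The two learning rates suggest separate optimizer parameter groups for the encoder and the head. This sketch shows one way to set that up and run a manual loop that keeps the best epoch by validation accuracy. It is an illustration under stated assumptions, not the author's training script: AdamW, the checkpoint-selection rule and the random `make_batch` data are assumptions, and `TinyEncoder` is a hypothetical stand-in for BERT-base so the code runs offline without downloading weights.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for BERT-base so the sketch runs offline.

    Returns an object with a `pooler_output` attribute, like Hugging Face's BertModel.
    """

    def __init__(self, vocab_size=100, hidden_size=768):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.pool = nn.Linear(hidden_size, hidden_size)

    def forward(self, input_ids, attention_mask):
        mask = attention_mask.unsqueeze(-1).float()
        # Masked mean over tokens, then a tanh projection (loosely imitating BERT's pooler)
        mean = (self.embed(input_ids) * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return SimpleNamespace(pooler_output=torch.tanh(self.pool(mean)))


class ShakespeareClassifier(nn.Module):
    """Mirrors the ShakespeareClassifier architecture described in this card."""

    def __init__(self, bert_model, num_classes=2, dropout_rate=0.1):
        super().__init__()
        self.bert = bert_model
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, input_ids, attention_mask):
        pooled = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        return self.classifier(self.dropout(pooled))


def make_batch(n=128, seq_len=16, vocab_size=100):
    """Random token ids and labels; placeholders for real tokenized data."""
    return {
        "input_ids": torch.randint(0, vocab_size, (n, seq_len)),
        "attention_mask": torch.ones(n, seq_len, dtype=torch.long),
        "labels": torch.randint(0, 2, (n,)),
    }


torch.manual_seed(0)
model = ShakespeareClassifier(TinyEncoder())
# Small LR for the pretrained encoder, larger LR for the freshly initialized head.
# AdamW is an assumption; the card does not name the optimizer.
optimizer = torch.optim.AdamW([
    {"params": model.bert.parameters(), "lr": 2e-5},
    {"params": model.classifier.parameters(), "lr": 1e-4},
])
loss_fn = nn.CrossEntropyLoss()


def train_step(batch):
    model.train()  # enable dropout
    optimizer.zero_grad()
    logits = model(batch["input_ids"], batch["attention_mask"])
    loss = loss_fn(logits, batch["labels"])
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def evaluate(batch):
    model.eval()  # disable dropout
    preds = model(batch["input_ids"], batch["attention_mask"]).argmax(dim=-1)
    return (preds == batch["labels"]).float().mean().item()


train_batch, val_batch = make_batch(), make_batch()
best_acc, best_epoch, best_state = -1.0, None, None
for epoch in range(1, 4):  # the card reports 3 epochs
    loss = train_step(train_batch)
    val_acc = evaluate(val_batch)
    if val_acc > best_acc:  # keep the checkpoint with the best validation accuracy
        best_acc, best_epoch = val_acc, epoch
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
```

In a real run, `TinyEncoder()` would be replaced by the pretrained BERT-base encoder, and each epoch would iterate over shuffled, tokenized mini-batches from a `DataLoader` rather than a single random batch.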
## Model Architecture

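```python
import torch.nn as nn


class ShakespeareClassifier(nn.Module):
    """BERT encoder followed by dropout and a linear classification head."""

    def __init__(self, bert_model, num_classes=2, dropout_rate=0.1):
        super().__init__()
        self.bert = bert_model  # pretrained BERT-base encoder (e.g. transformers.BertModel)
        self.dropout = nn.Dropout(dropout_rate)
        # 768 is the hidden size of BERT-base
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # pooler_output: [CLS] hidden state passed through BERT's dense + tanh pooler
        pooled_output = outputs.pooler_output
        x = self.dropout(pooled_output)
        logits = self.classifier(x)  # raw, unnormalized class scores
        return logits
```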
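
At inference time, the logits returned by `forward` are typically converted into a label and a confidence score with a softmax. This is a minimal sketch, assuming label `0` is modern text and `1` is Shakespearean; the `LABELS` mapping and the `predict` helper are hypothetical, so check the dataset's actual label convention. `StubClassifier` is a hypothetical placeholder that returns fixed logits, so the example runs without trained weights or a tokenizer. Any model with the same `forward(input_ids, attention_mask)` signature can be passed in its place.

```python
import torch
import torch.nn as nn

LABELS = {0: "modern", 1: "shakespeare"}  # hypothetical label mapping


@torch.no_grad()
def predict(model: nn.Module, input_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Return (label, confidence) pairs for a batch of tokenized texts."""
    model.eval()  # disables dropout
    probs = torch.softmax(model(input_ids, attention_mask), dim=-1)
    confidence, idx = probs.max(dim=-1)
    return [(LABELS[i], c) for i, c in zip(idx.tolist(), confidence.tolist())]


class StubClassifier(nn.Module):
    """Hypothetical placeholder returning fixed logits, for illustration only."""

    def forward(self, input_ids, attention_mask):
        return torch.tensor([[0.0, 3.0]]).repeat(input_ids.shape[0], 1)


ids = torch.zeros(2, 8, dtype=torch.long)
mask = torch.ones(2, 8, dtype=torch.long)
print(predict(StubClassifier(), ids, mask))
```

With a trained checkpoint, `input_ids` and `attention_mask` would come from the same BERT tokenizer used during training, for example via `tokenizer(texts, padding=True, truncation=True, return_tensors="pt")`.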