Upload README.md with huggingface_hub
Browse files
README.md
CHANGED
|
@@ -23,56 +23,46 @@ model-index:
|
|
| 23 |
metrics:
|
| 24 |
- name: Test RMSE
|
| 25 |
type: rmse
|
| 26 |
-
value: 0.
|
| 27 |
- name: Test R²
|
| 28 |
type: r2
|
| 29 |
-
value: 0.
|
| 30 |
-
- name: Test Loss
|
| 31 |
-
type: loss
|
| 32 |
-
value: 0.0002
|
| 33 |
---
|
| 34 |
|
| 35 |
# Topic Drift Detector Model
|
| 36 |
|
| 37 |
-
## Version:
|
| 38 |
|
| 39 |
-
This model detects topic drift in conversations using
|
| 40 |
|
| 41 |
## Model Architecture
|
| 42 |
-
- Efficient single-layer attention mechanism
|
| 43 |
-
- Direct pattern recognition
|
| 44 |
-
- Streamlined processing pipeline
|
| 45 |
-
- Optimized scaling factor (4.0)
|
| 46 |
-
- PreNorm layers with residual connections
|
| 47 |
|
| 48 |
### Key Components:
|
| 49 |
-
1. **
|
| 50 |
-
- Input dimension: 1024
|
| 51 |
- Hidden dimension: 512
|
| 52 |
-
-
|
| 53 |
-
- PreNorm layers with residual connections
|
| 54 |
|
| 55 |
2. **Attention Block**:
|
| 56 |
-
-
|
| 57 |
-
-
|
| 58 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
- Residual connections
|
| 60 |
|
| 61 |
-
|
| 62 |
-
-
|
| 63 |
-
-
|
| 64 |
-
-
|
| 65 |
|
| 66 |
## Performance Metrics
|
| 67 |
```txt
|
| 68 |
-
=== Full Training Results ===
|
| 69 |
-
Best Validation RMSE: 0.0142
|
| 70 |
-
Best Validation R²: 0.8711
|
| 71 |
-
|
| 72 |
=== Test Set Results ===
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
R²: 0.8666
|
| 76 |
```
|
| 77 |
|
| 78 |
## Training Details
|
|
@@ -85,58 +75,33 @@ R²: 0.8666
|
|
| 85 |
- Target standard deviation: 0.2
|
| 86 |
- Base embeddings: BAAI/bge-m3
|
| 87 |
|
| 88 |
-
## Key Improvements
|
| 89 |
-
1. **Simplified Architecture**:
|
| 90 |
-
- Reduced complexity
|
| 91 |
-
- Focused pattern detection
|
| 92 |
-
- Efficient processing
|
| 93 |
-
- Optimized memory usage
|
| 94 |
-
|
| 95 |
-
2. **Performance Benefits**:
|
| 96 |
-
- Improved RMSE (0.0144)
|
| 97 |
-
- Strong R² score (0.8666)
|
| 98 |
-
- Consistent predictions
|
| 99 |
-
- Wide score range
|
| 100 |
-
|
| 101 |
## Usage Example
|
| 102 |
|
| 103 |
-
|
| 104 |
-
|
| 105 |
pip install torch transformers huggingface_hub
|
| 106 |
-
```
|
| 107 |
|
| 108 |
-
|
| 109 |
-
```python
|
| 110 |
import torch
|
| 111 |
from transformers import AutoModel, AutoTokenizer
|
| 112 |
from huggingface_hub import hf_hub_download
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
repo_id=repo_id,
|
| 118 |
-
filename="models/latest/topic_drift_model.pt"
|
| 119 |
-
)
|
| 120 |
-
|
| 121 |
-
# Load checkpoint
|
| 122 |
-
checkpoint = torch.load(model_path, weights_only=True)
|
| 123 |
-
|
| 124 |
-
# Create model with same hyperparameters
|
| 125 |
-
model = EnhancedTopicDriftDetector(
|
| 126 |
-
input_dim=1024, # BGE-M3 embedding dimension
|
| 127 |
-
hidden_dim=checkpoint['hyperparameters']['hidden_dim']
|
| 128 |
-
)
|
| 129 |
-
|
| 130 |
-
# Load state dict
|
| 131 |
-
model.load_state_dict(checkpoint['model_state_dict'])
|
| 132 |
-
return model
|
| 133 |
-
|
| 134 |
-
# Load base embedding model
|
| 135 |
-
base_model = AutoModel.from_pretrained('BAAI/bge-m3')
|
| 136 |
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')
|
| 137 |
|
| 138 |
-
#
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
model.eval()
|
| 141 |
|
| 142 |
# Example conversation
|
|
@@ -151,19 +116,17 @@ conversation = [
|
|
| 151 |
"I couldn't believe that last-minute goal."
|
| 152 |
]
|
| 153 |
|
| 154 |
-
#
|
| 155 |
with torch.no_grad():
|
|
|
|
| 156 |
inputs = tokenizer(conversation, padding=True, truncation=True, return_tensors='pt')
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
# Reshape for model input [1, 8*1024]
|
| 160 |
-
conversation_embeddings = embeddings.view(1, -1)
|
| 161 |
|
| 162 |
# Get drift score
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
print(f"Topic drift score: {
|
| 166 |
-
# Higher scores indicate more topic drift
|
| 167 |
```
|
| 168 |
|
| 169 |
## Limitations
|
|
@@ -171,6 +134,3 @@ print(f"Topic drift score: {drift_scores.item():.4f}")
|
|
| 171 |
- Requires exactly 8 turns of conversation
|
| 172 |
- Each turn should be between 1-512 tokens
|
| 173 |
- Relies on BAAI/bge-m3 embeddings
|
| 174 |
-
|
| 175 |
-
## Training Curves
|
| 176 |
-

|
|
|
|
| 23 |
metrics:
|
| 24 |
- name: Test RMSE
|
| 25 |
type: rmse
|
| 26 |
+
value: 0.0165
|
| 27 |
- name: Test R²
|
| 28 |
type: r2
|
| 29 |
+
value: 0.0165
|
|
|
|
|
|
|
|
|
|
| 30 |
---
|
| 31 |
|
| 32 |
# Topic Drift Detector Model
|
| 33 |
|
| 34 |
+
## Version: v20241226_114030
|
| 35 |
|
| 36 |
+
This model detects topic drift in conversations using an efficient attention-based architecture. Trained on the [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2) dataset.
|
| 37 |
|
| 38 |
## Model Architecture
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
### Key Components:
|
| 41 |
+
1. **Input Processing**:
|
| 42 |
+
- Input dimension: 1024 (BGE-M3 embeddings)
|
| 43 |
- Hidden dimension: 512
|
| 44 |
+
- Sequence length: 8 turns
|
|
|
|
| 45 |
|
| 46 |
2. **Attention Block**:
|
| 47 |
+
- Multi-head attention (4 heads)
|
| 48 |
+
- PreNorm layers with residual connections
|
| 49 |
+
- Dropout rate: 0.1
|
| 50 |
+
|
| 51 |
+
3. **Feed-Forward Network**:
|
| 52 |
+
- Two-layer MLP with GELU activation
|
| 53 |
+
- Hidden dimension: 512 -> 2048 -> 512
|
| 54 |
- Residual connections
|
| 55 |
|
| 56 |
+
4. **Output Layer**:
|
| 57 |
+
- Two-layer MLP: 512 -> 256 -> 1
|
| 58 |
+
- GELU activation
|
| 59 |
+
- Direct sigmoid output for [0,1] range
|
| 60 |
|
| 61 |
## Performance Metrics
|
| 62 |
```txt
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
=== Test Set Results ===
|
| 64 |
+
RMSE: 0.0165
|
| 65 |
+
R��: 0.0165
|
|
|
|
| 66 |
```
|
| 67 |
|
| 68 |
## Training Details
|
|
|
|
| 75 |
- Target standard deviation: 0.2
|
| 76 |
- Base embeddings: BAAI/bge-m3
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
## Usage Example
|
| 79 |
|
| 80 |
+
```python
|
| 81 |
+
# Install dependencies
|
| 82 |
pip install torch transformers huggingface_hub
|
|
|
|
| 83 |
|
| 84 |
+
# Import required packages
|
|
|
|
| 85 |
import torch
|
| 86 |
from transformers import AutoModel, AutoTokenizer
|
| 87 |
from huggingface_hub import hf_hub_download
|
| 88 |
|
| 89 |
+
# Load base model and tokenizer
|
| 90 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 91 |
+
base_model = AutoModel.from_pretrained('BAAI/bge-m3').to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')
|
| 93 |
|
| 94 |
+
# Download and load topic drift model
|
| 95 |
+
model_path = hf_hub_download(
|
| 96 |
+
repo_id='leonvanbokhorst/topic-drift-detector',
|
| 97 |
+
filename='models/v20241226_114030/topic_drift_model.pt'
|
| 98 |
+
)
|
| 99 |
+
checkpoint = torch.load(model_path, weights_only=True, map_location=device)
|
| 100 |
+
model = EnhancedTopicDriftDetector(
|
| 101 |
+
input_dim=1024,
|
| 102 |
+
hidden_dim=checkpoint['hyperparameters']['hidden_dim']
|
| 103 |
+
).to(device)
|
| 104 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
| 105 |
model.eval()
|
| 106 |
|
| 107 |
# Example conversation
|
|
|
|
| 116 |
"I couldn't believe that last-minute goal."
|
| 117 |
]
|
| 118 |
|
| 119 |
+
# Process conversation
|
| 120 |
with torch.no_grad():
|
| 121 |
+
# Get embeddings
|
| 122 |
inputs = tokenizer(conversation, padding=True, truncation=True, return_tensors='pt')
|
| 123 |
+
inputs = dict((k, v.to(device)) for k, v in inputs.items())
|
| 124 |
+
embeddings = base_model(**inputs).last_hidden_state.mean(dim=1)
|
|
|
|
|
|
|
| 125 |
|
| 126 |
# Get drift score
|
| 127 |
+
conversation_embeddings = embeddings.view(1, -1)
|
| 128 |
+
drift_score = model(conversation_embeddings)
|
| 129 |
+
print(f"Topic drift score: {drift_score.item():.4f}")
|
|
|
|
| 130 |
```
|
| 131 |
|
| 132 |
## Limitations
|
|
|
|
| 134 |
- Requires exactly 8 turns of conversation
|
| 135 |
- Each turn should be between 1-512 tokens
|
| 136 |
- Relies on BAAI/bge-m3 embeddings
|
|
|
|
|
|
|
|
|