leonvanbokhorst committed
Commit 80a4e16 · verified · 1 Parent(s): 9a30c1f

Upload README.md with huggingface_hub

Files changed (1):
1. README.md (+44 −84)
README.md CHANGED
@@ -23,56 +23,46 @@ model-index:
   metrics:
   - name: Test RMSE
     type: rmse
-    value: 0.0144
+    value: 0.0165
   - name: Test R²
     type: r2
-    value: 0.8666
-  - name: Test Loss
-    type: loss
-    value: 0.0002
+    value: 0.0165
 ---
 
 # Topic Drift Detector Model
 
-## Version: v20241226_110212
+## Version: v20241226_114030
 
-This model detects topic drift in conversations using a streamlined attention-based architecture. Trained on the [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2) dataset.
+This model detects topic drift in conversations using an efficient attention-based architecture. Trained on the [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2) dataset.
 
 ## Model Architecture
-- Efficient single-layer attention mechanism
-- Direct pattern recognition
-- Streamlined processing pipeline
-- Optimized scaling factor (4.0)
-- PreNorm layers with residual connections
 
 ### Key Components:
-1. **Embedding Processor**:
-   - Input dimension: 1024
+1. **Input Processing**:
+   - Input dimension: 1024 (BGE-M3 embeddings)
    - Hidden dimension: 512
-   - Dropout rate: 0.35
-   - PreNorm layers with residual connections
+   - Sequence length: 8 turns
 
 2. **Attention Block**:
-   - Single attention layer
-   - Feed-forward dimension: 512
-   - Learned position encodings
+   - Multi-head attention (4 heads)
+   - PreNorm layers with residual connections
+   - Dropout rate: 0.1
+
+3. **Feed-Forward Network**:
+   - Two-layer MLP with GELU activation
+   - Hidden dimension: 512 -> 2048 -> 512
    - Residual connections
 
-3. **Pattern Recognition**:
-   - Direct feature extraction
-   - Efficient tensor operations
-   - Optimized memory usage
+4. **Output Layer**:
+   - Two-layer MLP: 512 -> 256 -> 1
+   - GELU activation
+   - Direct sigmoid output for [0,1] range
 
 ## Performance Metrics
 ```txt
-=== Full Training Results ===
-Best Validation RMSE: 0.0142
-Best Validation R²: 0.8711
-
 === Test Set Results ===
-Loss: 0.0002
-RMSE: 0.0144
-R²: 0.8666
+RMSE: 0.0165
+R²: 0.0165
 ```
 
 ## Training Details
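For readers trying to reproduce the usage example further down: the diff never shows the `EnhancedTopicDriftDetector` class itself. Below is a minimal sketch of what the updated component list could correspond to in PyTorch, assuming a standard pre-norm transformer block; the real implementation lives in the author's training code and may differ.

```python
import torch
import torch.nn as nn

class EnhancedTopicDriftDetector(nn.Module):
    """Sketch reconstructed from the component list above (an assumption, not the repo's code)."""

    def __init__(self, input_dim: int = 1024, hidden_dim: int = 512,
                 num_turns: int = 8, num_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        self.num_turns = num_turns
        # 1. Input processing: project 1024-dim BGE-M3 embeddings to hidden_dim (512)
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        # 2. Attention block: PreNorm multi-head attention (4 heads), dropout 0.1
        self.attn_norm = nn.LayerNorm(hidden_dim)
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads,
                                               dropout=dropout, batch_first=True)
        # 3. Feed-forward network: 512 -> 2048 -> 512 with GELU
        self.ff_norm = nn.LayerNorm(hidden_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_dim, 4 * hidden_dim),
            nn.GELU(),
            nn.Linear(4 * hidden_dim, hidden_dim),
        )
        # 4. Output layer: 512 -> 256 -> 1, sigmoid keeps scores in [0, 1]
        self.output = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The usage example feeds a flattened [batch, 8 * 1024] tensor
        x = x.view(x.shape[0], self.num_turns, -1)
        h = self.input_proj(x)
        # PreNorm attention with a residual connection
        q = self.attn_norm(h)
        attn_out, _ = self.attention(q, q, q)
        h = h + attn_out
        # PreNorm feed-forward with a residual connection
        h = h + self.feed_forward(self.ff_norm(h))
        # Pool over the 8 turns and map to a single drift score
        return self.output(h.mean(dim=1)).squeeze(-1)
```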
@@ -85,58 +75,33 @@ R²: 0.8666
 - Target standard deviation: 0.2
 - Base embeddings: BAAI/bge-m3
 
-## Key Improvements
-1. **Simplified Architecture**:
-   - Reduced complexity
-   - Focused pattern detection
-   - Efficient processing
-   - Optimized memory usage
-
-2. **Performance Benefits**:
-   - Improved RMSE (0.0144)
-   - Strong R² score (0.8666)
-   - Consistent predictions
-   - Wide score range
-
 ## Usage Example
 
-To use the model, first install the required packages:
-```bash
+```python
+# Install dependencies
 pip install torch transformers huggingface_hub
-```
 
-Then use the following code:
-```python
+# Import required packages
 import torch
 from transformers import AutoModel, AutoTokenizer
 from huggingface_hub import hf_hub_download
 
-def load_model(repo_id: str = "leonvanbokhorst/topic-drift-detector"):
-    # Download latest model weights
-    model_path = hf_hub_download(
-        repo_id=repo_id,
-        filename="models/latest/topic_drift_model.pt"
-    )
-
-    # Load checkpoint
-    checkpoint = torch.load(model_path, weights_only=True)
-
-    # Create model with same hyperparameters
-    model = EnhancedTopicDriftDetector(
-        input_dim=1024,  # BGE-M3 embedding dimension
-        hidden_dim=checkpoint['hyperparameters']['hidden_dim']
-    )
-
-    # Load state dict
-    model.load_state_dict(checkpoint['model_state_dict'])
-    return model
-
-# Load base embedding model
-base_model = AutoModel.from_pretrained('BAAI/bge-m3')
+# Load base model and tokenizer
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+base_model = AutoModel.from_pretrained('BAAI/bge-m3').to(device)
 tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')
 
-# Load topic drift detector from Hugging Face
-model = load_model()
+# Download and load topic drift model
+model_path = hf_hub_download(
+    repo_id='leonvanbokhorst/topic-drift-detector',
+    filename='models/v20241226_114030/topic_drift_model.pt'
+)
+checkpoint = torch.load(model_path, weights_only=True, map_location=device)
+model = EnhancedTopicDriftDetector(
+    input_dim=1024,
+    hidden_dim=checkpoint['hyperparameters']['hidden_dim']
+).to(device)
+model.load_state_dict(checkpoint['model_state_dict'])
 model.eval()
 
 # Example conversation
@@ -151,19 +116,17 @@ conversation = [
   "I couldn't believe that last-minute goal."
 ]
 
-# Get embeddings
+# Process conversation
 with torch.no_grad():
+    # Get embeddings
     inputs = tokenizer(conversation, padding=True, truncation=True, return_tensors='pt')
-    embeddings = base_model(**inputs).last_hidden_state.mean(dim=1)  # [8, 1024]
-
-    # Reshape for model input [1, 8*1024]
-    conversation_embeddings = embeddings.view(1, -1)
+    inputs = dict((k, v.to(device)) for k, v in inputs.items())
+    embeddings = base_model(**inputs).last_hidden_state.mean(dim=1)
 
     # Get drift score
-    drift_scores = model(conversation_embeddings)
-
-    print(f"Topic drift score: {drift_scores.item():.4f}")
-    # Higher scores indicate more topic drift
+    conversation_embeddings = embeddings.view(1, -1)
+    drift_score = model(conversation_embeddings)
+    print(f"Topic drift score: {drift_score.item():.4f}")
 ```
 
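Two caveats about this snippet, in the old and new version alike: `EnhancedTopicDriftDetector` is used but never imported or defined in the README, so the class has to come from the author's training code (or a reconstruction such as the sketch above), and the `pip install` line is a shell command even though the new version places it inside the `python` fence. Note also that the eight turn embeddings are flattened with `embeddings.view(1, -1)` to the `[1, 8 * 1024]` shape the detector expects.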
  ## Limitations
@@ -171,6 +134,3 @@ print(f"Topic drift score: {drift_score.item():.4f}")
 - Requires exactly 8 turns of conversation
 - Each turn should be between 1-512 tokens
 - Relies on BAAI/bge-m3 embeddings
-
-## Training Curves
-![Training Curves](plots/v20241226_110212/training_curves.png)
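The exactly-8-turns limitation means longer conversations have to be windowed before scoring. One way to do that, reusing the objects from the usage example; `drift_scores_over_windows` is a hypothetical helper sketched here, not part of the repository:

```python
import torch

def drift_scores_over_windows(turns, tokenizer, base_model, model, device,
                              window: int = 8):
    """Score every consecutive 8-turn window of a longer conversation.

    Hypothetical helper: the model card only guarantees behaviour for
    exactly `window` turns, so we slide a window rather than pad.
    """
    scores = []
    for start in range(len(turns) - window + 1):
        chunk = turns[start:start + window]
        inputs = tokenizer(chunk, padding=True, truncation=True,
                           max_length=512, return_tensors='pt')
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            # Mean-pool token embeddings per turn: [window, 1024]
            emb = base_model(**inputs).last_hidden_state.mean(dim=1)
            # Flatten to the [1, window * 1024] shape the detector expects
            scores.append(model(emb.view(1, -1)).item())
    return scores
```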
 