RavinduSen commited on
Commit
9800282
·
verified ·
1 Parent(s): 6264f9f

Upload 5 files

Browse files

Initial model upload — JaneGPT-v2 intent classifier

Files changed (5) hide show
  1. architecture.py +221 -0
  2. classifier.py +194 -0
  3. janegpt_v2_classifier.pt +3 -0
  4. requirements.txt +2 -0
  5. tokenizer.json +0 -0
architecture.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ JaneGPT v2 Model Architecture
3
+
4
+ A lightweight decoder-only transformer with classification head
5
+ for intent classification. Features modern architecture components:
6
+ - Rotary Position Embeddings (RoPE)
7
+ - Grouped Query Attention (GQA)
8
+ - SwiGLU feed-forward networks
9
+ - RMSNorm
10
+
11
+ Created by Ravindu Senanayake
12
+ """
13
+
14
+ import math
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from typing import Optional, Tuple
19
+
20
+
21
# Intent labels — exact order matters for classification
INTENT_LABELS = [
    "volume_up", "volume_down", "volume_set", "volume_mute",
    "brightness_up", "brightness_down", "brightness_set",
    "media_play", "media_pause", "media_next", "media_previous",
    "browser_search", "app_launch", "app_close", "app_switch",
    "set_reminder", "screenshot", "read_screen", "explain_screen",
    "undo", "chat", "quit_jane",
]
# Forward and reverse lookup tables derived from the canonical label order.
ID_TO_INTENT = dict(enumerate(INTENT_LABELS))
INTENT_TO_ID = {label: idx for idx, label in ID_TO_INTENT.items()}
NUM_INTENTS = len(INTENT_LABELS)
33
+
34
+
35
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (no mean-centering, no bias)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Divide each vector by its root-mean-square, then apply the learned gain.
        norm = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / norm * self.weight
45
+
46
+
47
class RotaryEmbedding(nn.Module):
    """Rotary Position Embeddings (RoPE).

    Precomputes cos/sin angle tables for ``max_seq_len`` positions at
    construction time and serves prefixes of them at forward time.

    Args:
        head_dim: Per-head dimension (channel pairs are rotated, so even).
        max_seq_len: Number of positions to cache tables for.
        theta: RoPE base frequency.
    """
    def __init__(self, head_dim, max_seq_len=512, theta=10000.0):
        super().__init__()
        # Pair frequencies: 1 / theta^(2i / head_dim) for i = 0, 1, ...
        inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer('inv_freq', inv_freq)
        # Angle table: position index times frequency, shape (max_seq_len, head_dim // 2).
        t = torch.arange(max_seq_len).float()
        freqs = torch.outer(t, inv_freq)
        self.register_buffer('cos_cached', torch.cos(freqs))
        self.register_buffer('sin_cached', torch.sin(freqs))

    def forward(self, seq_len):
        """Return (cos, sin) tables for the first ``seq_len`` positions.

        Raises:
            ValueError: if ``seq_len`` exceeds the cached ``max_seq_len``.
                Previously this silently returned a truncated table, which
                only surfaced later as an opaque shape/broadcast error
                inside attention.
        """
        cached = self.cos_cached.shape[0]
        if seq_len > cached:
            raise ValueError(
                f"seq_len {seq_len} exceeds cached max_seq_len {cached}"
            )
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
60
+
61
+
62
def apply_rope(x, cos, sin):
    """Apply rotary position embeddings (rotate-half form) to ``x``.

    ``x`` has shape (..., seq, head_dim); ``cos``/``sin`` have shape
    (seq, head_dim // 2) and are tiled to full width before broadcasting
    over the leading batch/head dimensions.
    """
    half = x.shape[-1] // 2
    lo, hi = x[..., :half], x[..., half:]
    # Rotate-half: (x1, x2) -> (-x2, x1).
    rotated = torch.cat((-hi, lo), dim=-1)
    cos_full = torch.cat((cos, cos), dim=-1)[None, None]
    sin_full = torch.cat((sin, sin), dim=-1)[None, None]
    return x * cos_full + rotated * sin_full
71
+
72
+
73
class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention (GQA).

    Projects to ``num_heads`` query heads but only ``num_kv_heads``
    key/value heads; each KV head is shared by ``num_heads // num_kv_heads``
    query heads, trading memory for (nearly) the same attention quality.
    RoPE is applied to q/k and the attention is causally masked.
    """
    def __init__(self, embed_dim, num_heads, num_kv_heads, head_dim,
                 max_seq_len, dropout, rope_theta):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        # Number of query heads that share a single KV head.
        self.num_groups = num_heads // num_kv_heads

        self.q_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.rope = RotaryEmbedding(head_dim, max_seq_len, rope_theta)

    def forward(self, x):
        bsz, seq, _ = x.shape

        def split_heads(t, n_heads):
            # (b, s, n*d) -> (b, n, s, d)
            return t.view(bsz, seq, n_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x), self.num_heads)
        k = split_heads(self.k_proj(x), self.num_kv_heads)
        v = split_heads(self.v_proj(x), self.num_kv_heads)

        cos, sin = self.rope(seq)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        # Duplicate KV heads so every query-head group sees its shared head.
        if self.num_groups > 1:
            k = k.repeat_interleave(self.num_groups, dim=1)
            v = v.repeat_interleave(self.num_groups, dim=1)

        # Scaled dot-product scores with a causal (upper-triangular) mask.
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        causal = torch.ones(seq, seq, device=x.device).triu(1).bool()
        scores = scores.masked_fill(causal, float('-inf'))

        attn = self.dropout(torch.softmax(scores, dim=-1))

        merged = (attn @ v).transpose(1, 2).contiguous().view(bsz, seq, -1)
        return self.out_proj(merged)
121
+
122
+
123
class SwiGLUFeedForward(nn.Module):
    """SwiGLU Feed-Forward Network: ``w2(silu(w1(x)) * w3(x))`` with dropout."""

    def __init__(self, embed_dim, ff_hidden, dropout=0.1):
        super().__init__()
        self.w1 = nn.Linear(embed_dim, ff_hidden, bias=False)  # gate branch
        self.w2 = nn.Linear(ff_hidden, embed_dim, bias=False)  # down-projection
        self.w3 = nn.Linear(embed_dim, ff_hidden, bias=False)  # value branch
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        projected = self.w2(gate * self.w3(x))
        return self.dropout(projected)
134
+
135
+
136
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: GQA then SwiGLU, each wrapped in a residual."""

    def __init__(self, embed_dim, num_heads, num_kv_heads, head_dim,
                 ff_hidden, max_seq_len, dropout, rope_theta):
        super().__init__()
        self.norm1 = RMSNorm(embed_dim)
        self.norm2 = RMSNorm(embed_dim)
        self.attn = GroupedQueryAttention(
            embed_dim, num_heads, num_kv_heads, head_dim,
            max_seq_len, dropout, rope_theta
        )
        self.ff = SwiGLUFeedForward(embed_dim, ff_hidden, dropout)

    def forward(self, x):
        # Residual around attention, then residual around the feed-forward.
        x = self.attn(self.norm1(x)) + x
        return self.ff(self.norm2(x)) + x
153
+
154
+
155
class JaneGPTv2Classifier(nn.Module):
    """
    JaneGPT v2 Intent Classifier.

    A decoder-only transformer with a classification head
    for 22-class intent classification.

    Args:
        vocab_size: Vocabulary size (default: 8192)
        embed_dim: Embedding dimension (default: 256)
        num_heads: Number of attention heads (default: 8)
        num_kv_heads: Number of KV heads for GQA (default: 4)
        num_layers: Number of transformer layers (default: 8)
        ff_hidden: Feed-forward hidden dimension (default: 672)
        max_seq_len: Maximum sequence length (default: 256)
        dropout: Dropout rate (default: 0.1)
        rope_theta: RoPE theta parameter (default: 10000.0)
    """
    def __init__(self, vocab_size=8192, embed_dim=256, num_heads=8,
                 num_kv_heads=4, num_layers=8, ff_hidden=672,
                 max_seq_len=256, dropout=0.1, rope_theta=10000.0):
        super().__init__()

        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        head_dim = embed_dim // num_heads

        blocks = [
            TransformerBlock(embed_dim, num_heads, num_kv_heads, head_dim,
                             ff_hidden, max_seq_len, dropout, rope_theta)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.norm = RMSNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

        # Two-layer MLP head mapping the pooled hidden state to intent logits.
        self.intent_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, NUM_INTENTS),
        )

    def forward(self, x, labels=None):
        """Return ``(logits, loss)``; ``loss`` is None unless ``labels`` is given."""
        hidden = self.dropout(self.token_embedding(x))
        for block in self.layers:
            hidden = block(hidden)
        hidden = self.norm(hidden)

        # Pool the final position's hidden state for classification.
        # NOTE(review): the inference wrapper right-pads inputs, so this is
        # usually a pad position — presumably matches training; confirm.
        logits = self.intent_head(hidden[:, -1, :])

        if labels is None:
            return logits, None
        return logits, F.cross_entropy(logits, labels)

    @torch.no_grad()
    def predict(self, x):
        """Return ``(class_id, confidence)`` for a single-example batch.

        Uses ``.item()``, so ``x`` must have batch size 1.
        """
        logits, _ = self.forward(x)
        probs = F.softmax(logits, dim=-1)
        confidence, predicted = probs.max(dim=-1)
        return predicted.item(), confidence.item()
classifier.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ JaneGPT v2 Intent Classifier — Inference Wrapper
3
+
4
+ Simple interface for intent classification.
5
+
6
+ Usage:
7
+ from model.classifier import JaneGPTClassifier
8
+
9
+ classifier = JaneGPTClassifier()
10
+ intent, confidence = classifier.predict("turn up the volume")
11
+
12
+ Created by Ravindu Senanayake
13
+ """
14
+
15
+ from pathlib import Path
16
+ from typing import Optional, Dict, Tuple, List
17
+
18
+ import torch
19
+
20
+ from model.architecture import JaneGPTv2Classifier, ID_TO_INTENT, INTENT_LABELS
21
+
22
+
23
class JaneGPTClassifier:
    """
    Ready-to-use intent classifier.

    Loads the trained model and tokenizer, provides simple
    predict() interface for intent classification.

    Args:
        model_path: Path to trained checkpoint (.pt file)
        tokenizer_path: Path to BPE tokenizer (.json file)
        device: "auto", "cuda", or "cpu"
    """

    MAX_LEN = 128  # fixed model input length: inputs are truncated/right-padded to this
    PAD_ID = 0     # token id used for right-padding

    def __init__(
        self,
        model_path: str = "weights/janegpt_v2_classifier.pt",
        tokenizer_path: str = "weights/tokenizer.json",
        device: str = "auto",
    ):
        self.model_path = Path(model_path)
        self.tokenizer_path = Path(tokenizer_path)
        self.is_ready = False

        # "auto" prefers CUDA when available, otherwise CPU.
        if device == "auto":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        self.tokenizer = None
        self.model = None
        self.id_to_intent = ID_TO_INTENT

        self._load()

    def _load(self):
        """Load model and tokenizer.

        Raises:
            FileNotFoundError: if the checkpoint or tokenizer file is missing.
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        if not self.tokenizer_path.exists():
            raise FileNotFoundError(f"Tokenizer not found: {self.tokenizer_path}")

        # Load tokenizer (local import keeps `tokenizers` optional at module import)
        from tokenizers import Tokenizer
        self.tokenizer = Tokenizer.from_file(str(self.tokenizer_path))

        # Load model.
        # SECURITY: weights_only=False unpickles arbitrary objects from the
        # checkpoint — only load checkpoints from a trusted source.
        checkpoint = torch.load(
            self.model_path, map_location=self.device, weights_only=False
        )

        # Fall back to the architecture defaults for any key the checkpoint's
        # config dict does not carry.
        config = checkpoint.get('config', {})

        self.model = JaneGPTv2Classifier(
            vocab_size=config.get('vocab_size', 8192),
            embed_dim=config.get('embed_dim', 256),
            num_heads=config.get('num_heads', 8),
            num_kv_heads=config.get('num_kv_heads', 4),
            num_layers=config.get('num_layers', 8),
            ff_hidden=config.get('ff_hidden', 672),
            max_seq_len=config.get('max_seq_len', 256),
            dropout=config.get('dropout', 0.1),
            rope_theta=config.get('rope_theta', 10000.0),
        )

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()
        self.is_ready = True

    def _format_input(self, text: str, context: Optional[Dict] = None) -> str:
        """Format input for the model (prompt template used at training time)."""
        if context and context.get('last_intent'):
            ctx_str = f"last_action={context['last_intent']}"
        else:
            ctx_str = "none"

        return f"user: {text}\ncontext: {ctx_str}\njane:"

    def _tokenize(self, text: str) -> torch.Tensor:
        """Tokenize and pad to MAX_LEN; returns a (1, MAX_LEN) long tensor."""
        ids = self.tokenizer.encode(text).ids

        # NOTE(review): truncation can cut off the trailing "jane:" marker for
        # very long utterances — confirm this matches training-time handling.
        if len(ids) > self.MAX_LEN:
            ids = ids[:self.MAX_LEN]
        else:
            ids = ids + [self.PAD_ID] * (self.MAX_LEN - len(ids))

        return torch.tensor([ids], dtype=torch.long, device=self.device)

    def predict(
        self,
        text: str,
        context: Optional[Dict] = None
    ) -> Tuple[str, float]:
        """
        Predict intent for given text.

        Args:
            text: User utterance (e.g., "turn up the volume")
            context: Optional dict with 'last_intent' key

        Returns:
            Tuple of (intent_label, confidence)

        Raises:
            RuntimeError: if the model failed to load.

        Example:
            >>> classifier.predict("open chrome")
            ('app_launch', 0.981)
        """
        if not self.is_ready:
            raise RuntimeError("Model not loaded")

        formatted = self._format_input(text, context)
        input_ids = self._tokenize(formatted)

        predicted_idx, confidence = self.model.predict(input_ids)
        intent = self.id_to_intent.get(predicted_idx, 'chat')

        return intent, confidence

    def predict_top_k(
        self,
        text: str,
        context: Optional[Dict] = None,
        k: int = 3
    ) -> List[Tuple[str, float]]:
        """
        Get top-k predictions with confidences.

        Args:
            text: User utterance
            context: Optional context dict
            k: Number of top predictions to return (clamped to the number
               of supported intents)

        Returns:
            List of (intent_label, confidence) tuples

        Raises:
            RuntimeError: if the model failed to load.

        Example:
            >>> classifier.predict_top_k("play something", k=3)
            [('media_play', 0.85), ('browser_search', 0.08), ('chat', 0.03)]
        """
        if not self.is_ready:
            raise RuntimeError("Model not loaded")

        # torch.topk raises if k exceeds the class count — clamp defensively.
        k = min(k, len(INTENT_LABELS))

        formatted = self._format_input(text, context)
        input_ids = self._tokenize(formatted)

        with torch.no_grad():
            logits, _ = self.model(input_ids)
            probs = torch.softmax(logits, dim=-1)
            top_probs, top_indices = probs.topk(k, dim=-1)

        return [
            (self.id_to_intent.get(idx.item(), 'chat'), prob.item())
            for prob, idx in zip(top_probs[0], top_indices[0])
        ]

    @staticmethod
    def get_supported_intents() -> List[str]:
        """Get list of all supported intent labels."""
        return INTENT_LABELS.copy()

    def __repr__(self):
        return (
            f"JaneGPTClassifier("
            f"ready={self.is_ready}, "
            f"device={self.device}, "
            f"intents={len(INTENT_LABELS)})"
        )
janegpt_v2_classifier.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57136d3827707d96d34a52b615dcb74fbbbec53c86203787784386b0a110ced1
3
+ size 31513149
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch>=2.0.0
2
+ tokenizers>=0.13.0
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff