JaneGPT-v2 / model /architecture.py
RavinduSen's picture
Model + Tokenizer
50771e3 verified
"""
JaneGPT v2 Model Architecture
A lightweight decoder-only transformer with classification head
for intent classification. Features modern architecture components:
- Rotary Position Embeddings (RoPE)
- Grouped Query Attention (GQA)
- SwiGLU feed-forward networks
- RMSNorm
Created by Ravindu Senanayake
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
# Intent labels — exact order matters for classification
INTENT_LABELS = [
"volume_up", "volume_down", "volume_set", "volume_mute",
"brightness_up", "brightness_down", "brightness_set",
"media_play", "media_pause", "media_next", "media_previous",
"browser_search", "app_launch", "app_close", "app_switch",
"set_reminder", "screenshot", "read_screen", "explain_screen",
"undo", "chat", "quit_jane",
]
INTENT_TO_ID = {label: i for i, label in enumerate(INTENT_LABELS)}
ID_TO_INTENT = {i: label for i, label in enumerate(INTENT_LABELS)}
NUM_INTENTS = len(INTENT_LABELS)
class RMSNorm(nn.Module):
"""Root Mean Square Layer Normalization."""
def __init__(self, dim, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.eps = eps
def forward(self, x):
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
return (x / rms) * self.weight
class RotaryEmbedding(nn.Module):
"""Rotary Position Embeddings (RoPE)."""
def __init__(self, head_dim, max_seq_len=512, theta=10000.0):
super().__init__()
inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer('inv_freq', inv_freq)
t = torch.arange(max_seq_len).float()
freqs = torch.outer(t, inv_freq)
self.register_buffer('cos_cached', torch.cos(freqs))
self.register_buffer('sin_cached', torch.sin(freqs))
def forward(self, seq_len):
return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
def apply_rope(x, cos, sin):
"""Apply rotary position embeddings to input tensor."""
head_dim = x.shape[-1]
x1 = x[..., :head_dim // 2]
x2 = x[..., head_dim // 2:]
rotated = torch.cat([-x2, x1], dim=-1)
cos = cos.unsqueeze(0).unsqueeze(0).repeat(1, 1, 1, 2)
sin = sin.unsqueeze(0).unsqueeze(0).repeat(1, 1, 1, 2)
return x * cos + rotated * sin
class GroupedQueryAttention(nn.Module):
"""
Grouped Query Attention (GQA).
Uses fewer KV heads than query heads for memory efficiency
while maintaining attention quality.
"""
def __init__(self, embed_dim, num_heads, num_kv_heads, head_dim,
max_seq_len, dropout, rope_theta):
super().__init__()
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.num_groups = num_heads // num_kv_heads
self.q_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=False)
self.k_proj = nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.dropout = nn.Dropout(dropout)
self.rope = RotaryEmbedding(head_dim, max_seq_len, rope_theta)
def forward(self, x):
batch_size, seq_len, _ = x.shape
q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rope(seq_len)
q = apply_rope(q, cos, sin)
k = apply_rope(k, cos, sin)
if self.num_groups > 1:
k = k.repeat_interleave(self.num_groups, dim=1)
v = v.repeat_interleave(self.num_groups, dim=1)
scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
scores = scores.masked_fill(mask, float('-inf'))
attn_weights = torch.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
out = attn_weights @ v
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
return self.out_proj(out)
class SwiGLUFeedForward(nn.Module):
"""SwiGLU Feed-Forward Network."""
def __init__(self, embed_dim, ff_hidden, dropout=0.1):
super().__init__()
self.w1 = nn.Linear(embed_dim, ff_hidden, bias=False)
self.w2 = nn.Linear(ff_hidden, embed_dim, bias=False)
self.w3 = nn.Linear(embed_dim, ff_hidden, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class TransformerBlock(nn.Module):
"""Single transformer block with GQA and SwiGLU."""
def __init__(self, embed_dim, num_heads, num_kv_heads, head_dim,
ff_hidden, max_seq_len, dropout, rope_theta):
super().__init__()
self.norm1 = RMSNorm(embed_dim)
self.norm2 = RMSNorm(embed_dim)
self.attn = GroupedQueryAttention(
embed_dim, num_heads, num_kv_heads, head_dim,
max_seq_len, dropout, rope_theta
)
self.ff = SwiGLUFeedForward(embed_dim, ff_hidden, dropout)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.ff(self.norm2(x))
return x
class JaneGPTv2Classifier(nn.Module):
"""
JaneGPT v2 Intent Classifier.
A decoder-only transformer with a classification head
for 22-class intent classification.
Args:
vocab_size: Vocabulary size (default: 8192)
embed_dim: Embedding dimension (default: 256)
num_heads: Number of attention heads (default: 8)
num_kv_heads: Number of KV heads for GQA (default: 4)
num_layers: Number of transformer layers (default: 8)
ff_hidden: Feed-forward hidden dimension (default: 672)
max_seq_len: Maximum sequence length (default: 256)
dropout: Dropout rate (default: 0.1)
rope_theta: RoPE theta parameter (default: 10000.0)
"""
def __init__(self, vocab_size=8192, embed_dim=256, num_heads=8,
num_kv_heads=4, num_layers=8, ff_hidden=672,
max_seq_len=256, dropout=0.1, rope_theta=10000.0):
super().__init__()
self.embed_dim = embed_dim
self.max_seq_len = max_seq_len
self.token_embedding = nn.Embedding(vocab_size, embed_dim)
head_dim = embed_dim // num_heads
self.layers = nn.ModuleList([
TransformerBlock(
embed_dim, num_heads, num_kv_heads, head_dim,
ff_hidden, max_seq_len, dropout, rope_theta
)
for _ in range(num_layers)
])
self.norm = RMSNorm(embed_dim)
self.dropout = nn.Dropout(dropout)
self.intent_head = nn.Sequential(
nn.Linear(embed_dim, embed_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(embed_dim, NUM_INTENTS),
)
def forward(self, x, labels=None):
x = self.dropout(self.token_embedding(x))
for layer in self.layers:
x = layer(x)
x = self.norm(x)
pooled = x[:, -1, :]
logits = self.intent_head(pooled)
loss = None
if labels is not None:
loss = F.cross_entropy(logits, labels)
return logits, loss
@torch.no_grad()
def predict(self, x):
logits, _ = self.forward(x)
probs = F.softmax(logits, dim=-1)
confidence, predicted = torch.max(probs, dim=-1)
return predicted.item(), confidence.item()