| """
|
Display detailed information about the JaneGPT v2 model.
|
|
|
| Shows architecture, parameters, training info, and size comparisons.
|
| """
|
|
|
| import os
|
| import torch
|
| from model.architecture import JaneGPTv2Classifier, INTENT_LABELS
|
|
|
|
|
# Fallback hyperparameters used when the checkpoint's "config" omits a key.
# The same merged values drive both model construction and the printed report.
_DEFAULT_CONFIG = {
    "vocab_size": 8192,
    "embed_dim": 256,
    "num_heads": 8,
    "num_kv_heads": 4,
    "num_layers": 8,
    "ff_hidden": 672,
    "max_seq_len": 256,
    "dropout": 0.1,
    "rope_theta": 10000.0,
}

# Static reference rows for the size comparison table: (name, params, size).
_REFERENCE_MODELS = (
    ("DistilBERT", "66,000,000", "260.0 MB"),
    ("BERT Base", "110,000,000", "440.0 MB"),
    ("GPT-2 Small", "124,000,000", "500.0 MB"),
    ("Llama 3 8B", "8,000,000,000", " 16.0 GB"),
    ("GPT-4", "~1,800,000,000,000", "~ 3.6 TB"),
)


def _count(tensors) -> int:
    """Return the total number of elements across an iterable of tensors."""
    return sum(t.numel() for t in tensors)


def _mb(num_bytes: float) -> float:
    """Convert a byte count to mebibytes."""
    return num_bytes / 1024 / 1024


def _print_share(name: str, count: int, total: int) -> None:
    """Print one top-level breakdown row with its percentage of *total*."""
    print(f" {name:<35} {count:>12,} {count / total * 100:>7.1f}%")


def _print_architecture(cfg: dict, num_classes: int) -> None:
    """Print the hyperparameter summary and the fixed feature list."""
    print("\n ARCHITECTURE")
    print(" Type: Decoder-only Transformer (Classifier)")
    print(f" Vocab Size: {cfg['vocab_size']:,}")
    print(f" Embedding Dim: {cfg['embed_dim']}")
    print(f" Attention Heads: {cfg['num_heads']}")
    print(f" KV Heads (GQA): {cfg['num_kv_heads']}")
    print(f" Head Dim: {cfg['embed_dim'] // cfg['num_heads']}")
    print(f" Layers: {cfg['num_layers']}")
    print(f" FF Hidden: {cfg['ff_hidden']}")
    print(f" Max Seq Length: {cfg['max_seq_len']}")
    print(f" Dropout: {cfg['dropout']}")
    print(f" RoPE Theta: {cfg['rope_theta']}")

    print("\n FEATURES")
    print(" Position Encoding: RoPE (Rotary Position Embedding)")
    print(" Normalization: RMSNorm")
    print(" Attention: Grouped Query Attention (GQA)")
    print(" Feed-Forward: SwiGLU")
    print(" Classifier Head: Linear -> GELU -> Dropout -> Linear")
    print(f" Output Classes: {num_classes}")


def _print_parameter_breakdown(model: torch.nn.Module, total_params: int,
                               num_layers: int) -> None:
    """Print per-component parameter counts, measured from the live model."""
    print("\n PARAMETER BREAKDOWN")
    print(f" {'Component':<35} {'Params':>12} {'%':>8}")
    print(f" {'-' * 55}")

    _print_share("Token Embedding",
                 _count(model.token_embedding.parameters()), total_params)
    _print_share("Transformer Layers (total)",
                 _count(model.layers.parameters()), total_params)

    # Layer 0 stands in for every layer; this assumes all layers share one shape.
    layer0 = model.layers[0]
    # .parameters() never yields buffers, so no buffer subtraction is needed
    # (subtracting them would under-count, possibly below zero).
    attn_params = _count(layer0.attn.parameters())
    ff_params = _count(layer0.ff.parameters())
    norm_params = (_count(layer0.norm1.parameters())
                   + _count(layer0.norm2.parameters()))
    per_layer_label = f" Per layer (x{num_layers}):"
    print(f" {per_layer_label:<33} {_count(layer0.parameters()):>12,}")
    print(f" {' Attention (Q/K/V/Out)':<33} {attn_params:>12,}")
    print(f" {' Feed-Forward (SwiGLU)':<33} {ff_params:>12,}")
    print(f" {' Norms (RMSNorm x2)':<33} {norm_params:>12,}")

    _print_share("Final RMSNorm", _count(model.norm.parameters()), total_params)

    _print_share("Classification Head",
                 _count(model.intent_head.parameters()), total_params)
    # Describe the head's Linear layers from the modules themselves rather
    # than hard-coding dimensions that change with embed_dim / label count.
    for layer in model.intent_head.modules():
        if isinstance(layer, torch.nn.Linear):
            bias = " + bias" if layer.bias is not None else ""
            label = f" Linear({layer.in_features}, {layer.out_features}){bias}"
            print(f" {label:<33} {_count(layer.parameters()):>12,}")


def _print_training(checkpoint: dict) -> None:
    """Print stored training metrics, showing N/A for any that are absent."""
    val_acc = checkpoint.get("val_acc")
    val_loss = checkpoint.get("val_loss")
    acc_text = "N/A" if val_acc is None else f"{val_acc:.2f}%"
    loss_text = "N/A" if val_loss is None else f"{val_loss:.4f}"
    print("\n TRAINING")
    print(f" Best Val Accuracy: {acc_text}")
    print(f" Best Val Loss: {loss_text}")
    print(f" Best Epoch: {checkpoint.get('epoch', 'N/A')}")


def _print_files(checkpoint_path: str, tokenizer_path: str) -> None:
    """Print on-disk sizes for whichever artifact files exist."""
    print("\n FILES")
    if os.path.exists(checkpoint_path):
        print(f" Checkpoint: {_mb(os.path.getsize(checkpoint_path)):.2f} MB")
    if os.path.exists(tokenizer_path):
        print(f" Tokenizer: {os.path.getsize(tokenizer_path) / 1024:.1f} KB")


def _print_size_comparison(total_params: int) -> None:
    """Print this model's size next to a table of well-known models."""
    this_size = f"{_mb(total_params * 4):.1f} MB"
    print("\n SIZE COMPARISON")
    print(f" {'Model':<30} {'Parameters':>15} {'Size':>10}")
    print(f" {'-' * 55}")
    # Same column widths as the header so the row lines up with the others.
    print(f" {'JaneGPT v2 (this model)':<30} {total_params:>15,} {this_size:>10}")
    for name, params, size in _REFERENCE_MODELS:
        print(f" {name:<30} {params:>15} {size:>10}")


def main(checkpoint_path: str = "weights/janegpt_v2_classifier.pt",
         tokenizer_path: str = "weights/tokenizer.json") -> None:
    """Load the JaneGPT v2 classifier checkpoint and print a full report.

    Args:
        checkpoint_path: Checkpoint file containing "model_state_dict" and
            optionally "config", "val_acc", "val_loss" and "epoch".
        tokenizer_path: Tokenizer file whose size is reported if it exists.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only point this at checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    # Merge once so construction and the printed summary use identical values.
    cfg = {**_DEFAULT_CONFIG, **(checkpoint.get("config") or {})}

    model = JaneGPTv2Classifier(**{key: cfg[key] for key in _DEFAULT_CONFIG})
    model.load_state_dict(checkpoint["model_state_dict"])

    total_params = _count(model.parameters())
    trainable_params = _count(p for p in model.parameters() if p.requires_grad)
    buffers = _count(model.buffers())

    print("=" * 60)
    print(" JANEGPT v2 - MODEL INFORMATION")
    print("=" * 60)

    _print_architecture(cfg, len(INTENT_LABELS))

    print("\n PARAMETERS")
    print(f" Total Parameters: {total_params:>12,}")
    print(f" Trainable Parameters: {trainable_params:>12,}")
    print(f" Non-trainable Buffers: {buffers:>12,}")
    print(f" Model Size (float32): {_mb(total_params * 4):.2f} MB")
    print(f" Model Size (float16): {_mb(total_params * 2):.2f} MB")

    _print_parameter_breakdown(model, total_params, cfg["num_layers"])
    _print_training(checkpoint)

    print(f"\n INTENT CLASSES ({len(INTENT_LABELS)})")
    for index, label in enumerate(INTENT_LABELS):
        print(f" {index:>2}: {label}")

    _print_files(checkpoint_path, tokenizer_path)
    _print_size_comparison(total_params)

    print("\n Created by: Ravindu Senanayake")
    print("=" * 60)
|
|
|
|
|
# Print the model report only when run as a script, not when imported.
if __name__ == "__main__":
    main()