---
license: mit
datasets:
- pfb30/multi_woz_v22
language:
- en
pipeline_tag: text-generation
library_name: transformers
tags:
- Sam-2
- text-generation
---

# 🧠 Model Card: Sam‑2.5

## 📌 Model Overview
**Sam‑2.5** is a minimal, modular, decoder‑only Transformer architecture designed for chat‑style reasoning tasks.
It emphasizes reproducibility, ablation‑friendly design, and clean benchmarking across input modalities.

- **Architecture**: Decoder‑only Transformer with RMSNorm, SwiGLU feed‑forward, and causal masking
- **Training Objective**: Causal language modeling (CLM) with role‑based label masking
- **Checkpoint**: `sam2-epoch35.safetensors`
- **Final Train Loss**: 1.04
- **Validation Loss**: Not tracked during training (see Evaluation)
- **Training Duration**: ~6272 s over 35 epochs
- **Framework**: PyTorch + Hugging Face Transformers (custom model class)

## 🧱 Model Architecture
| Component | Description |
|-------------------|-----------------------------------------------------------------------------|
| Backbone | Decoder‑only Transformer stack |
| Normalization | RMSNorm |
| Attention | Multi‑head self‑attention (causal) |
| Feed‑Forward | SwiGLU activation with dropout |
| Positional Bias | Learned absolute positions (no RoPE in this minimal variant) |
| Head | Tied‑embedding LM head |
| Checkpoint Format | `safetensors` with metadata for reproducibility |

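For orientation, here is a minimal sketch of the block structure the table describes, assuming standard RMSNorm and SwiGLU formulations and PyTorch's built‑in attention module; the class names are illustrative and are not the actual `sam2` implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """Root-mean-square norm: rescales by 1/RMS(x); no mean subtraction, no bias."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

class SwiGLU(nn.Module):
    """Gated feed-forward: down(SiLU(gate(x)) * up(x)) with dropout."""
    def __init__(self, dim, hidden, dropout=0.1):
        super().__init__()
        self.gate = nn.Linear(dim, hidden, bias=False)
        self.up = nn.Linear(dim, hidden, bias=False)
        self.down = nn.Linear(hidden, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.down(F.silu(self.gate(x)) * self.up(x)))

class DecoderBlock(nn.Module):
    """Pre-norm decoder block: causal self-attention followed by a SwiGLU MLP."""
    def __init__(self, dim, n_heads, dropout=0.1):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout, batch_first=True)
        self.norm2 = RMSNorm(dim)
        self.mlp = SwiGLU(dim, 4 * dim, dropout)

    def forward(self, x):
        # Boolean causal mask: True marks positions a token may NOT attend to.
        T = x.size(1)
        causal = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=causal, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x

# x: (batch, seq, dim) activations after token + position embeddings
block = DecoderBlock(dim=256, n_heads=4)
out = block(torch.randn(2, 16, 256))
```

In the full model, learned absolute position embeddings are added to the token embeddings before the block stack, and the LM head shares weights with the input embedding (tied‑embedding head), as listed in the table.
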
## 🧪 Training Details
- **Dataset**: [pfb30/multi_woz_v22](https://huggingface.co/datasets/pfb30/multi_woz_v22)
- **Batch Size**: 8
- **Optimizer**: AdamW
- **Learning Rate**: 2 × 10⁻⁴ (constant in this run)
- **Loss Function**: Cross‑entropy over assistant tokens only (see the masking sketch below)
- **Hardware**: Kaggle GPU runtime
- **Logging**: Step‑wise loss tracking, no validation during training

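The role‑based label masking behind "cross‑entropy over assistant tokens only" can be sketched as follows: every token outside an assistant reply gets the ignore index, so only assistant tokens contribute to the loss. The helper and span format below are illustrative assumptions, not the training code used for this run.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # labels with this value are skipped by cross-entropy

def mask_non_assistant_labels(input_ids, assistant_spans):
    """Copy input_ids into labels, then blank out every position that is not
    inside an assistant span, so only assistant tokens contribute to the loss."""
    labels = torch.full_like(input_ids, IGNORE_INDEX)
    for start, end in assistant_spans:  # token-index ranges covering assistant replies
        labels[start:end] = input_ids[start:end]
    return labels

# Toy example: tokens 0-4 form the user turn, tokens 5-9 the assistant reply.
input_ids = torch.arange(10)
labels = mask_non_assistant_labels(input_ids, [(5, 10)])

# Standard causal-LM shift: predict token t+1 from tokens <= t, skipping masked labels.
logits = torch.randn(1, 10, 32000)  # (batch, seq, vocab) -- dummy model output
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, logits.size(-1)),  # predictions for positions 0..T-2
    labels[1:].unsqueeze(0).reshape(-1),          # targets shifted left by one
    ignore_index=IGNORE_INDEX,
)
print(loss)
```
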
## 📊 Evaluation
| Metric | Value | Notes |
|------------------|--------|---------------------------------------|
| Final Train Loss | 1.04   | Achieved at Epoch 35/35               |
| Validation Loss  | 1.9554 |                                       |
| Inference Speed  | Fast   | Lightweight architecture              |
| Generalisation   | TBD    | To be compared against Sam‑2          |

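For context, mean token‑level cross‑entropy converts to perplexity via `exp(loss)` (assuming the losses above are reported in nats):

```python
import math

# perplexity = exp(mean cross-entropy loss)
print(math.exp(1.04))    # ~2.83 -- final train loss
print(math.exp(1.9554))  # ~7.07 -- reported validation loss
```
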
## 🔧 Intended Use
- **Research**: Benchmarking modular architectures and ablation studies
- **Education**: Reasoning scaffolds and logic quizzes
- **Deployment**: Lightweight agents for chat and dialogue modeling

## 🚫 Limitations
- No validation tracking during training; generalisation must be inferred via external harnesses
- Trained on MultiWOZ v2.2 only; may not generalize to other domains without fine‑tuning
- Minimal architecture; no RoPE or MQA in this variant

## 📁 Files
- `sam2-epoch35.safetensors`: final checkpoint
- `config.json`: architecture and training config
- `tokenizer.json`: tokenizer with special tokens
- `README.md`: training logs and setup instructions

## 🧩 How to Load
```python
import json

import torch
from safetensors.torch import load_file
from transformers import AutoTokenizer

from sam2 import Sam2, Sam2Config  # your custom model class

tok = AutoTokenizer.from_pretrained("Smilyai-labs/Sam-2.0")
cfg = Sam2Config(**json.load(open("config.json")))
model = Sam2(cfg)
state = load_file("sam2-epoch35.safetensors")  # safetensors checkpoints need load_file, not torch.load
model.load_state_dict(state)
model.eval()

# Greedy generation: append the most likely next token until <eos> or 50 new tokens.
prompt = "<|user|> Hello! <|eot|>\n<|assistant|>"
ids = tok.encode(prompt, return_tensors="pt")
with torch.no_grad():
    for _ in range(50):
        logits = model(ids)
        next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)
        if next_id.item() == tok.eos_token_id:
            break

print(tok.decode(ids[0]))
```
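
Greedy argmax decoding can loop or produce flat replies; the same loop can sample from a tempered top‑k distribution instead. This variant reuses `model` and `tok` from the snippet above and assumes `model(ids)` returns logits of shape `(batch, seq, vocab)`; the temperature and top‑k values are illustrative.

```python
# Same loop as above, but sample the next token instead of taking the argmax.
temperature, top_k = 0.8, 50  # illustrative sampling parameters
ids = tok.encode(prompt, return_tensors="pt")  # start again from the prompt
with torch.no_grad():
    for _ in range(50):
        logits = model(ids)[:, -1, :] / temperature      # logits for the last position
        topk_vals, topk_idx = torch.topk(logits, top_k)  # keep the k most likely tokens
        probs = torch.softmax(topk_vals, dim=-1)
        next_id = topk_idx.gather(-1, torch.multinomial(probs, 1))
        ids = torch.cat([ids, next_id], dim=1)
        if next_id.item() == tok.eos_token_id:
            break

print(tok.decode(ids[0]))
```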