# Inference Guide
## Quick Start
```bash
python generate.py
```
**Default Output:**
```
Prompt: 'The history of '
--------------------------------------------------
The history of Tomadination of the [[New Gouple de aparty]]...
```
---
## Basic Usage
### 1. Load Model
```python
from src.models.agiformer import AGIFORMER
import torch
model = AGIFORMER(d_model=512, n_layers=6, patch_size=4, thinking_steps=3)
model.load_state_dict(torch.load("best_model.pth", map_location="cpu"))  # map_location lets GPU-trained checkpoints load on CPU
model.eval()
```
### 2. Prepare Input
```python
prompt = "The history of artificial intelligence"
input_bytes = [ord(c) for c in prompt]
# Pad to patch_size boundary
pad = (4 - len(input_bytes) % 4) % 4
input_bytes.extend([32] * pad)
x = torch.tensor(input_bytes).unsqueeze(0) # (1, seq_len)
```
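The padding arithmetic above can be checked in isolation. A minimal sketch of the same logic (the `pad_to_patch` helper is illustrative, not part of the repo):

```python
def pad_to_patch(byte_ids, patch_size=4, pad_byte=32):
    """Right-pad a byte list with spaces (32) so its length is a multiple of patch_size."""
    pad = (patch_size - len(byte_ids) % patch_size) % patch_size
    return byte_ids + [pad_byte] * pad

prompt = "The history of artificial intelligence"  # 38 characters
padded = pad_to_patch([ord(c) for c in prompt])
print(len(padded))  # 40: 38 bytes + 2 spaces of padding
```

Note the double modulo: an input that already lands on a patch boundary gets zero padding rather than a full extra patch.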
### 3. Generate
```python
with torch.no_grad():
    output = model(x, temperature=0.7)  # (1, num_patches, patch_size)

# Decode
generated_bytes = output[0, -1, :].tolist()
text = ''.join(chr(b) for b in generated_bytes if 32 <= b <= 126)
```
---
## Temperature Sampling
### Greedy (Temperature = 0)
```python
output = model(x, temperature=0.0)
```
- Picks the most likely byte at every step
- **Deterministic** (same output each run)
- Prone to repetition loops
**Example:**
```
The history of of of of of...
```
### Low Temperature (0.3 - 0.5)
```python
output = model(x, temperature=0.3)
```
- Slightly random, still conservative
- Good for **coherent** text
- Reduces repetition
**Example:**
```
The history of the computer system...
```
### Medium Temperature (0.7 - 0.9)
```python
output = model(x, temperature=0.7) # Default
```
- Balanced creativity/coherence
- **Recommended** for exploration
**Example:**
```
The history of Tomadination of the [[New Gouple]]...
```
### High Temperature (1.0+)
```python
output = model(x, temperature=1.2)
```
- Very random
- Incoherent but diverse
- Useful for **probing** what the model has learned
**Example:**
```
The history qw8#$x [[zap]] nullification...
```
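The four settings above all come down to the standard `logits / T` rescaling before the softmax. A pure-Python sketch on a toy distribution (not code from the repo) shows the effect:

```python
import math

def softmax_with_temperature(logits, temperature):
    """Scale logits by 1/T, then softmax. Low T sharpens the distribution; high T flattens it."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # Subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]
print(softmax_with_temperature(logits, 0.3))  # Sharp: mass concentrates on the top logit
print(softmax_with_temperature(logits, 1.2))  # Flat: mass spreads across all candidates
```

Greedy decoding (temperature = 0) is the limit of this scaling: the distribution collapses onto the argmax, which is why it is deterministic.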
---
## Advanced: Token-by-Token Generation
For streaming output:
```python
import torch

def generate_stream(model, prompt, max_tokens=200, temperature=0.7):
    # Encode the prompt and pad to a patch_size boundary
    context = [ord(c) for c in prompt]
    pad = (4 - len(context) % 4) % 4
    context.extend([32] * pad)
    for _ in range(max_tokens // 4):  # Generate patch-by-patch
        x = torch.tensor(context[-1024:]).unsqueeze(0)  # Sliding window
        with torch.no_grad():
            pred = model(x, temperature=temperature)
        # Append the last predicted patch to the context
        new_bytes = pred[0, -1, :].cpu().tolist()
        context.extend(new_bytes)
        # Decode printable bytes and stream them out
        chunk = ''.join(chr(b) for b in new_bytes if 32 <= b <= 126)
        print(chunk, end='', flush=True)
```
**Usage:**
```python
generate_stream(model, "The history of ", max_tokens=200)
```
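The sliding-window bookkeeping in `generate_stream` can be exercised without the model. A minimal sketch with the model call stubbed out (the `slide_window` helper is illustrative):

```python
def slide_window(context, new_bytes, window=1024):
    """Append newly generated bytes, then return the last `window` bytes as the next input."""
    context.extend(new_bytes)
    return context[-window:]

context = [32] * 1022  # Near the window limit
next_input = slide_window(context, [104, 105, 33, 33])  # A freshly generated 4-byte patch
print(len(next_input))  # 1024: the two oldest bytes fell out of the window
```

Because the window is a fixed-size suffix, memory stays bounded no matter how long generation runs.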
---
## System 2 Control
### Disable Thinking (Baseline)
```python
model = AGIFORMER(thinking_steps=0) # Skip System 2
```
- Faster inference (~2× speedup)
- Lower quality output
### Increase Thinking
```python
model = AGIFORMER(thinking_steps=5) # More refinement
```
- Slower inference
- Potentially better coherence
### Runtime Control
`thinking_steps` is fixed when the model is constructed, so changing it requires building a new instance:
```python
# Not possible to change thinking_steps after model creation
# Must create new model with desired config
```
---
## Batch Inference
Process multiple prompts:
```python
import torch
import torch.nn.functional as F

prompts = ["The history of ", "In the year 2050, ", "Once upon a time, "]
batch = []
for prompt in prompts:
    byte_ids = [ord(c) for c in prompt]  # Avoid shadowing the builtin `bytes`
    pad = (4 - len(byte_ids) % 4) % 4
    byte_ids.extend([32] * pad)
    batch.append(torch.tensor(byte_ids))

# Pad all sequences to the same length (32 = ASCII space)
max_len = max(t.size(0) for t in batch)
batch_tensor = torch.stack([
    F.pad(t, (0, max_len - t.size(0)), value=32)
    for t in batch
])

# Generate
with torch.no_grad():
    outputs = model(batch_tensor, temperature=0.7)
```
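Decoding the batched result back into one string per prompt uses the same printable-ASCII filter as the single-prompt case. A sketch assuming `outputs[i, -1, :].tolist()` yields the last generated patch for prompt `i` (shown here on hand-written byte lists so it runs standalone):

```python
def decode_patch(byte_ids):
    """Keep only printable ASCII bytes, matching the single-prompt decode."""
    return ''.join(chr(b) for b in byte_ids if 32 <= b <= 126)

# Stand-ins for outputs[i, -1, :].tolist(), one per prompt
last_patches = [[116, 104, 101, 32], [50, 48, 53, 48], [0, 111, 110, 200]]
for prompt, patch in zip(["The history of ", "In the year 2050, ", "Once upon a time, "],
                         last_patches):
    print(repr(prompt), '->', repr(decode_patch(patch)))
```

The filter silently drops control bytes and non-ASCII values, so decoded patches can be shorter than `patch_size`.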
---
## Debugging Output
### Check Raw Bytes
```python
generated = model(x, temperature=0.0)
raw_bytes = generated[0, -1, :].tolist()
print(f"Raw: {raw_bytes}") # e.g., [116, 104, 101, 32]
```
### Detect Non-Printables
```python
for b in raw_bytes:
    if not (32 <= b <= 126):
        print(f"Warning: non-printable byte {b}")
```
### Measure Entropy
```python
import torch
import torch.nn.functional as F

logits = model.head(latents)  # Raw logits, given `latents` from a forward pass
probs = F.softmax(logits, dim=-1)
entropy = -(probs * torch.log2(probs + 1e-10)).sum(dim=-1).mean()  # log2 → bits
print(f"Avg Entropy: {entropy.item():.2f} bits")
# Low (< 2 bits): confident, may repeat
# High (> 6 bits): uncertain, output will look random
```
---
## Common Issues
### Repetition Loops
**Problem:**
```
of of of of of...
```
**Solutions:**
1. Increase temperature: `0.0 → 0.7`
2. Use nucleus sampling (top-p):
```python
probs = F.softmax(logits / temperature, dim=-1)
sorted_probs, indices = torch.sort(probs, descending=True)
cumsum = torch.cumsum(sorted_probs, dim=-1)
# Drop tokens outside the top-90% nucleus; the shift keeps at least the top token
mask = cumsum - sorted_probs > 0.9
sorted_probs[mask] = 0.0
sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)  # Renormalize the nucleus
next_byte = indices.gather(-1, torch.multinomial(sorted_probs, 1))  # Map back to byte IDs
```
### Gibberish Output
**Problem:**
```
xq#$8z [[nullification]]...
```
**Causes:**
- Temperature too high
- Model undertrained
**Solutions:**
- Lower temperature: `1.2 → 0.5`
- Train longer (20k+ steps)
### Slow Inference
**Problem:** >1s per token
**Solutions:**
- Use GPU: `model.cuda()`
- Reduce `thinking_steps`: `3 → 1`
- Disable System 2: `thinking_steps=0`
---
## Performance Benchmarks
**GPU:** NVIDIA T4
**Prompt Length:** 100 bytes
**Generation Length:** 200 bytes
| Config | Latency (per token) | Throughput |
|--------|---------------------|------------|
| Greedy (temp=0) | 45ms | 22 tokens/s |
| Sampling (temp=0.7) | 52ms | 19 tokens/s |
| System 2 disabled | 28ms | 36 tokens/s |
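Assuming the latency column is per generated token, throughput is simply its reciprocal. A quick sanity check of the table:

```python
latencies_ms = {"greedy": 45, "sampling": 52, "no_system2": 28}
for name, ms in latencies_ms.items():
    print(f"{name}: {round(1000 / ms)} tokens/s")
# Prints 22, 19, and 36 tokens/s, matching the table
```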
---
## API Reference
### Model Forward
```python
def forward(
    x: torch.Tensor,                              # (batch, seq_len) bytes
    target_bytes: Optional[torch.Tensor] = None,  # For training
    temperature: float = 0.0,                     # Sampling temp (0 = greedy)
) -> torch.Tensor:
    ...
    # Returns: (batch, num_patches, patch_size, 256) if training
    #          (batch, num_patches, patch_size) if inference
```
### Generation Utilities
See `generate.py` for full implementation:
- `generate_text(model_path, prompt, max_tokens, temperature)`
- Automatic padding and decoding
---
## Next Steps
1. **Experiment with Prompts:** Try different domains
2. **Tune Temperature:** Find sweet spot for your use case
3. **Extend Context:** Modify `generate.py` to use longer contexts
4. **Fine-tune:** Retrain on domain-specific data