# Trouter-20B Usage Guide
## Installation
```bash
pip install transformers torch accelerate bitsandbytes
```
## Quick Start
### Basic Text Generation
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Load model and tokenizer
model_name = "your-username/Trouter-20B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
# Generate text
prompt = "Explain quantum computing in simple terms:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    temperature=0.7,
    top_p=0.95,
    do_sample=True
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
```
### Chat Interface
```python
def chat(messages, max_new_tokens=512):
    """
    Chat with the model using a conversation history.

    Args:
        messages: List of dicts with 'role' and 'content' keys
        max_new_tokens: Maximum tokens to generate

    Example:
        messages = [
            {"role": "user", "content": "What is machine learning?"}
        ]
    """
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return response
# Example usage
conversation = [
    {"role": "user", "content": "Hello! Can you help me with Python?"}
]
response = chat(conversation)
print(response)
# Continue conversation
conversation.append({"role": "assistant", "content": response})
conversation.append({"role": "user", "content": "Show me how to read a CSV file."})
response = chat(conversation)
print(response)
```
### Memory-Efficient Loading (8-bit Quantization)
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
model_name = "your-username/Trouter-20B"
# Load in 8-bit for reduced memory usage
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,
    device_map="auto",
    torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
```
### 4-bit Quantization (Even Lower Memory)
```python
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

model_name = "your-username/Trouter-20B"

# Configure 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
```
## Advanced Usage
### Batch Generation
```python
prompts = [
    "Write a poem about AI:",
    "Explain neural networks:",
    "What is reinforcement learning?"
]

# Causal LM tokenizers often lack a pad token; reuse EOS and pad on the left for generation
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=128,
    temperature=0.8,
    top_p=0.95,
    num_return_sequences=1,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id
)
responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
for prompt, response in zip(prompts, responses):
    print(f"Prompt: {prompt}")
    print(f"Response: {response}\n")
```
### Streaming Generation
```python
from transformers import TextIteratorStreamer
from threading import Thread
# skip_prompt=True streams only the newly generated text, not the prompt itself
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

prompt = "Write a story about a robot:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

generation_kwargs = {
    **inputs,
    "max_new_tokens": 256,
    "temperature": 0.7,
    "do_sample": True,
    "streamer": streamer
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
print("Generated text: ", end="")
for new_text in streamer:
    print(new_text, end="", flush=True)
print()
```
### Custom Generation Parameters
```python
# Creative generation
creative_output = model.generate(
    **inputs,
    max_new_tokens=256,
    temperature=1.0,  # Higher = more creative
    top_p=0.95,
    top_k=50,
    repetition_penalty=1.2,
    do_sample=True
)

# Deterministic generation (sampling disabled, so temperature/top_p are ignored)
deterministic_output = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=False,
    num_beams=4  # Beam search for quality
)
```
## Fine-tuning
### Using PEFT (Parameter-Efficient Fine-Tuning)
```python
from peft import LoraConfig, get_peft_model
from transformers import TrainingArguments, Trainer
# Configure LoRA
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
# Apply LoRA to model
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Training arguments
training_args = TrainingArguments(
    output_dir="./trouter-finetuned",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    num_train_epochs=3,
    logging_steps=10,
    save_steps=100,
    fp16=True
)
# Train
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset
)
trainer.train()
```
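After training, a common pattern is to save only the LoRA adapter and reattach it to the base model for inference. The sketch below illustrates this with the PEFT APIs; the adapter directory name is illustrative, not part of the original guide:

```python
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Save just the LoRA adapter weights (small compared to the full model)
model.save_pretrained("./trouter-finetuned-adapter")  # hypothetical path

# Later: reload the base model and attach the adapter for inference
base_model = AutoModelForCausalLM.from_pretrained(
    "your-username/Trouter-20B",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("your-username/Trouter-20B")
model = PeftModel.from_pretrained(base_model, "./trouter-finetuned-adapter")

# Optionally merge the adapter into the base weights to remove the PEFT overhead
model = model.merge_and_unload()
```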
## Performance Optimization
### GPU Memory Requirements
- **Full precision (bfloat16)**: ~40GB VRAM
- **8-bit quantization**: ~20GB VRAM
- **4-bit quantization**: ~10GB VRAM
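If you are unsure which configuration fits your hardware, you can check actual GPU usage after loading. A small sketch (assumes CUDA GPUs and a model already loaded as in the quick start):

```python
import torch

# Report allocated vs. total memory on each visible GPU once the model is loaded
for i in range(torch.cuda.device_count()):
    allocated_gb = torch.cuda.memory_allocated(i) / 1024**3
    total_gb = torch.cuda.get_device_properties(i).total_memory / 1024**3
    print(f"GPU {i}: {allocated_gb:.1f} GB allocated / {total_gb:.1f} GB total")
```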
### Recommendations
- Use `device_map="auto"` for automatic multi-GPU distribution
- Enable `torch.compile()` (PyTorch 2.0+) for faster inference; a sketch follows the Flash Attention example below
- Use Flash Attention 2, if available, for better performance:
```python
# Enable Flash Attention 2
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2"
)
```
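The `torch.compile()` recommendation above can be applied once the model is loaded. A minimal sketch (requires PyTorch 2.0+; the first generation call is slower while the graph compiles, and speedups vary by hardware and model):

```python
import torch

# Compile the model's forward pass; subsequent generate() calls reuse the compiled graph
model = torch.compile(model)
```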
## Troubleshooting
### Out of Memory Errors
1. Use quantization (8-bit or 4-bit)
2. Reduce `max_new_tokens`
3. Decrease batch size
4. Enable gradient checkpointing for fine-tuning (see the sketch below)
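For item 4, gradient checkpointing recomputes activations during the backward pass instead of storing them, trading training speed for memory. A minimal sketch, assuming the PEFT fine-tuning setup from the section above:

```python
from transformers import TrainingArguments

# Recompute activations on the backward pass (slower steps, much lower VRAM)
model.gradient_checkpointing_enable()

# Or let the Trainer manage it via TrainingArguments
training_args = TrainingArguments(
    output_dir="./trouter-finetuned",
    gradient_checkpointing=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    fp16=True
)
```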
### Slow Generation
1. Use smaller `max_new_tokens`
2. Disable `do_sample` for greedy decoding (see the sketch after this list)
3. Use Flash Attention 2
4. Consider model quantization
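For item 2, greedy decoding skips sampling entirely and is typically the fastest decoding strategy. A minimal sketch using the model and tokenizer from the quick start (the prompt is just an example):

```python
inputs = tokenizer("Summarize the benefits of unit testing:", return_tensors="pt").to(model.device)

# Greedy decoding: no sampling, single beam, short output for lower latency
outputs = model.generate(
    **inputs,
    max_new_tokens=64,
    do_sample=False,
    num_beams=1
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```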
### Poor Quality Outputs
1. Adjust `temperature` (0.7-0.9 recommended)
2. Tune `top_p` and `top_k` values
3. Add `repetition_penalty` (1.1-1.3)
4. Ensure proper prompt formatting (e.g., use the chat template for conversational inputs)
## Community and Support
- **Issues**: [GitHub Issues](https://github.com/your-username/Trouter-20B/issues)
- **Discussions**: [Hugging Face Discussions](https://huggingface.co/your-username/Trouter-20B/discussions)
- **Discord**: [Community Discord](#)
## Citation
If you use Trouter-20B in your research, please cite:
```bibtex
@software{trouter20b2025,
  title={Trouter-20B: A 20 Billion Parameter Language Model},
  author={Your Name},
  year={2025},
  url={https://huggingface.co/your-username/Trouter-20B}
}
```