fix: correct Python API example in docs
Browse files
README.md
CHANGED
|
@@ -421,14 +421,31 @@ genesis --model ./genesis_152m_instruct.safetensors
|
|
| 421 |
### Python API
|
| 422 |
|
| 423 |
```python
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
prompt = """<|im_start|>system
|
| 433 |
You are a helpful assistant.
|
| 434 |
<|im_end|>
|
|
@@ -438,10 +455,12 @@ Explain what linear attention is in simple terms.
|
|
| 438 |
<|im_start|>assistant
|
| 439 |
"""
|
| 440 |
|
| 441 |
-
# Generate
|
| 442 |
-
input_ids = tokenizer.encode(prompt,
|
| 443 |
-
|
| 444 |
-
|
|
|
|
|
|
|
| 445 |
```
|
| 446 |
|
| 447 |
### Prompt Format
|
|
|
|
| 421 |
### Python API
|
| 422 |
|
| 423 |
```python
|
| 424 |
+
import json
|
| 425 |
+
import torch
|
| 426 |
+
from safetensors import safe_open
|
| 427 |
+
from safetensors.torch import load_file
|
| 428 |
+
from genesis import Genesis, GenesisConfig, get_tokenizer
|
| 429 |
+
|
| 430 |
+
# 1. Load config from checkpoint metadata
|
| 431 |
+
model_path = "./genesis_152m_instruct.safetensors"
|
| 432 |
+
with safe_open(model_path, framework="pt", device="cpu") as f:
|
| 433 |
+
metadata = f.metadata() or {}
|
| 434 |
+
config_dict = json.loads(metadata.get("genesis_config_json", "{}"))
|
| 435 |
+
config = GenesisConfig(**config_dict) if config_dict else GenesisConfig.genesis_147m()
|
| 436 |
+
|
| 437 |
+
# 2. Load model weights
|
| 438 |
+
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
| 439 |
+
state_dict = load_file(model_path, device=device)
|
| 440 |
+
model = Genesis(config).to(device)
|
| 441 |
+
model.load_state_dict(state_dict, strict=False)
|
| 442 |
+
model.eval()
|
| 443 |
+
|
| 444 |
+
# 3. Setup tokenizer (GPT-NeoX + ChatML tokens)
|
| 445 |
+
tokenizer = get_tokenizer("neox")
|
| 446 |
+
tokenizer.add_chat_tokens()
|
| 447 |
+
|
| 448 |
+
# 4. Build ChatML prompt
|
| 449 |
prompt = """<|im_start|>system
|
| 450 |
You are a helpful assistant.
|
| 451 |
<|im_end|>
|
|
|
|
| 455 |
<|im_start|>assistant
|
| 456 |
"""
|
| 457 |
|
| 458 |
+
# 5. Generate
|
| 459 |
+
input_ids = torch.tensor([tokenizer.encode(prompt)], device=device)
|
| 460 |
+
with torch.no_grad():
|
| 461 |
+
output_ids = model.generate(input_ids, max_new_tokens=256, temperature=0.7)
|
| 462 |
+
response = tokenizer.decode(output_ids[0][input_ids.shape[1]:].tolist())
|
| 463 |
+
print(response)
|
| 464 |
```
|
| 465 |
|
| 466 |
### Prompt Format
|