guiferrarib committed on
Commit
6ff11d8
·
verified ·
1 Parent(s): b34c50d

fix: correct Python API example in docs

Browse files
Files changed (1) hide show
  1. README.md +31 -12
README.md CHANGED
@@ -421,14 +421,31 @@ genesis --model ./genesis_152m_instruct.safetensors
421
  ### Python API
422
 
423
  ```python
424
- from genesis import Genesis, GenesisConfig
425
- from genesis.tokenizer import GenesisTokenizer
426
-
427
- # Load model
428
- model = Genesis.from_pretrained("./genesis_152m_instruct.safetensors")
429
- tokenizer = GenesisTokenizer()
430
-
431
- # ChatML format
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  prompt = """<|im_start|>system
433
  You are a helpful assistant.
434
  <|im_end|>
@@ -438,10 +455,12 @@ Explain what linear attention is in simple terms.
438
  <|im_start|>assistant
439
  """
440
 
441
- # Generate
442
- input_ids = tokenizer.encode(prompt, return_tensors="pt")
443
- output = model.generate(input_ids, max_new_tokens=256, temperature=0.7)
444
- print(tokenizer.decode(output[0]))
 
 
445
  ```
446
 
447
  ### Prompt Format
 
421
  ### Python API
422
 
423
  ```python
424
+ import json
425
+ import torch
426
+ from safetensors import safe_open
427
+ from safetensors.torch import load_file
428
+ from genesis import Genesis, GenesisConfig, get_tokenizer
429
+
430
+ # 1. Load config from checkpoint metadata
431
+ model_path = "./genesis_152m_instruct.safetensors"
432
+ with safe_open(model_path, framework="pt", device="cpu") as f:
433
+ metadata = f.metadata() or {}
434
+ config_dict = json.loads(metadata.get("genesis_config_json", "{}"))
435
+ config = GenesisConfig(**config_dict) if config_dict else GenesisConfig.genesis_147m()
436
+
437
+ # 2. Load model weights
438
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
439
+ state_dict = load_file(model_path, device=device)
440
+ model = Genesis(config).to(device)
441
+ model.load_state_dict(state_dict, strict=False)
442
+ model.eval()
443
+
444
+ # 3. Setup tokenizer (GPT-NeoX + ChatML tokens)
445
+ tokenizer = get_tokenizer("neox")
446
+ tokenizer.add_chat_tokens()
447
+
448
+ # 4. Build ChatML prompt
449
  prompt = """<|im_start|>system
450
  You are a helpful assistant.
451
  <|im_end|>
 
455
  <|im_start|>assistant
456
  """
457
 
458
+ # 5. Generate
459
+ input_ids = torch.tensor([tokenizer.encode(prompt)], device=device)
460
+ with torch.no_grad():
461
+ output_ids = model.generate(input_ids, max_new_tokens=256, temperature=0.7)
462
+ response = tokenizer.decode(output_ids[0][input_ids.shape[1]:].tolist())
463
+ print(response)
464
  ```
465
 
466
  ### Prompt Format