RetentionLabs commited on
Commit
54b7585
·
verified ·
1 Parent(s): 8b03813

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +5 -4
README.md CHANGED
@@ -59,15 +59,16 @@ model_id = "RetentionLabs/TTT-Linear-1.3B-Base-Books-32k"
59
  tokenizer = AutoTokenizer.from_pretrained(model_id)
60
  model = AutoModelForCausalLM.from_pretrained(
61
  model_id,
62
- torch_dtype=torch.bfloat16,
63
  trust_remote_code=True,
 
64
  device_map="auto"
65
  )
66
 
67
  # Generate
68
- inputs = tokenizer("The future of AI is", return_tensors="pt").to(model.device)
69
- outputs = model.generate(**inputs, max_new_tokens=100)
70
- print(tokenizer.decode(outputs[0], skip_special_tokens=True))
 
71
  ```
72
 
73
  ### From scratch
 
59
  tokenizer = AutoTokenizer.from_pretrained(model_id)
60
  model = AutoModelForCausalLM.from_pretrained(
61
  model_id,
 
62
  trust_remote_code=True,
63
+ dtype=torch.bfloat16,
64
  device_map="auto"
65
  )
66
 
67
  # Generate
68
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
69
+ inputs = tokenizer("The future of AI is", return_tensors="pt").to(model.device)
70
+ outputs = model.generate(**inputs, max_new_tokens=100)
71
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
72
  ```
73
 
74
  ### From scratch