lennart-finke committed on
Commit
6555b2d
·
verified ·
1 Parent(s): 45914ee

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +7 -5
README.md CHANGED
@@ -35,9 +35,10 @@ model_size = "5M" # Options: "35M", "30M", "11M", "5M", "1.25M"
35
  model_config = MODEL_CONFIGS[model_size]
36
 
37
  # Load appropriate model
38
- model_path = f"chandan-sreedhara/SimpleStories-{model_size}"
39
  model = Llama.from_pretrained(model_path, model_config)
40
- model.to("cuda")
 
41
  model.eval()
42
 
43
  # Load tokenizer
@@ -47,14 +48,14 @@ tokenizer = AutoTokenizer.from_pretrained(model_path)
47
  prompt = "The curious cat looked at the"
48
 
49
  inputs = tokenizer(prompt, return_tensors="pt")
50
- input_ids = inputs.input_ids.to("cuda")
51
 
52
  # Generate text
53
  with torch.no_grad():
54
  output_ids = model.generate(
55
  idx=input_ids,
56
- max_new_tokens=800,
57
- temperature=0.7,
58
  top_k=40,
59
  eos_token_id=tokenizer.eos_token_id
60
  )
@@ -62,6 +63,7 @@ with torch.no_grad():
62
  # Decode output
63
  output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
64
  print(f"Generated text:\n{output_text}")
 
65
  ```
66
 
67
  ## Model Variants
 
35
  model_config = MODEL_CONFIGS[model_size]
36
 
37
  # Load appropriate model
38
+ model_path = f"SimpleStories/SimpleStories-{model_size}"
39
  model = Llama.from_pretrained(model_path, model_config)
40
+ device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
41
+ model.to(device)
42
  model.eval()
43
 
44
  # Load tokenizer
 
48
  prompt = "The curious cat looked at the"
49
 
50
  inputs = tokenizer(prompt, return_tensors="pt")
51
+ input_ids = inputs.input_ids.to(device)
52
 
53
  # Generate text
54
  with torch.no_grad():
55
  output_ids = model.generate(
56
  idx=input_ids,
57
+ max_new_tokens=50,
58
+ temperature=0.0,
59
  top_k=40,
60
  eos_token_id=tokenizer.eos_token_id
61
  )
 
63
  # Decode output
64
  output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
65
  print(f"Generated text:\n{output_text}")
66
+
67
  ```
68
 
69
  ## Model Variants