Girinath11 commited on
Commit
82a94c2
·
verified ·
1 Parent(s): 2caeef3

Update model_usage.py

Browse files
Files changed (1) hide show
  1. model_usage.py +3 -25
model_usage.py CHANGED
@@ -1,9 +1,6 @@
1
- from mixture_of_recursion import RecursiveLanguageModel, RecursiveLanguageModelConfig
2
- from transformers import AutoTokenizer
3
  import torch
4
-
5
- # Load model
6
- model = RecursiveLanguageModel.from_pretrained(
7
  "Girinath11/recursive-language-model-198m",
8
  trust_remote_code=True
9
  )
@@ -11,25 +8,18 @@ tokenizer = AutoTokenizer.from_pretrained(
11
  "Girinath11/recursive-language-model-198m",
12
  trust_remote_code=True
13
  )
14
-
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  model = model.to(device)
17
  model.eval()
18
-
19
  print(f"βœ… Model loaded on {device}")
20
  print(f"πŸ“Š Parameters: {sum(p.numel() for p in model.parameters()):,}\n")
21
-
22
-
23
  def chat(question, max_new_tokens=150, temperature=0.7, top_p=0.9):
24
- # Must use chat format β€” model trained on this
25
  prompt = f"<|user|>\n{question}\n<|assistant|>\n"
26
-
27
  inputs = tokenizer(
28
  prompt,
29
  return_tensors="pt",
30
  add_special_tokens=False
31
  ).to(device)
32
-
33
  with torch.no_grad():
34
  outputs = model.generate(
35
  inputs['input_ids'],
@@ -38,31 +28,19 @@ def chat(question, max_new_tokens=150, temperature=0.7, top_p=0.9):
38
  top_p=top_p,
39
  do_sample=True,
40
  )
41
-
42
  full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
43
-
44
- # Extract only assistant response
45
  if "<|assistant|>" in full_text:
46
  response = full_text.split("<|assistant|>")[-1].strip()
47
  else:
48
  response = full_text.replace(question, "").strip()
49
 
50
  return response
51
-
52
-
53
- # Test
54
  questions = [
55
  "What is machine learning?",
56
  "What is Python programming?",
57
  "Explain neural networks simply",
58
  "What is artificial intelligence?",
59
  ]
60
-
61
- print("=" * 55)
62
- print("πŸ€– MIXTURE OF RECURSION LM β€” 198M")
63
- print("=" * 55)
64
-
65
  for q in questions:
66
  print(f"\n❓ {q}")
67
- print(f"πŸ’¬ {chat(q)}")
68
- print("-" * 55)
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
2
  import torch
3
+ model = AutoModelForCausalLM.from_pretrained(
 
 
4
  "Girinath11/recursive-language-model-198m",
5
  trust_remote_code=True
6
  )
 
8
  "Girinath11/recursive-language-model-198m",
9
  trust_remote_code=True
10
  )
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  model = model.to(device)
13
  model.eval()
 
14
  print(f"βœ… Model loaded on {device}")
15
  print(f"πŸ“Š Parameters: {sum(p.numel() for p in model.parameters()):,}\n")
 
 
def chat(question, max_new_tokens=150, temperature=0.7, top_p=0.9):
    """Generate an assistant reply for *question* using the chat template.

    The model was trained on the ``<|user|>``/``<|assistant|>`` format, so
    the prompt must be wrapped accordingly.

    Args:
        question: The user's question as plain text.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature (higher = more random).
        top_p: Nucleus-sampling probability mass.

    Returns:
        The assistant's response text with the prompt stripped.
    """
    prompt = f"<|user|>\n{question}\n<|assistant|>\n"

    # Template already carries the special markers, so skip auto special tokens.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            inputs['input_ids'],
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )

    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Keep only the text after the last assistant marker; fall back to
    # stripping the echoed question if the marker was dropped by decoding.
    if "<|assistant|>" in full_text:
        response = full_text.split("<|assistant|>")[-1].strip()
    else:
        response = full_text.replace(question, "").strip()

    return response
 
 
 
# Smoke-test the chat helper on a handful of sample prompts.
questions = [
    "What is machine learning?",
    "What is Python programming?",
    "Explain neural networks simply",
    "What is artificial intelligence?",
]

for question in questions:
    print(f"\n❓ {question}")
    print(f"πŸ’¬ {chat(question)}")