convaiinnovations committed on
Commit
e91b2cb
·
verified ·
1 Parent(s): efaf053

Upload interactive_inference.py

Browse files
Files changed (1) hide show
  1. interactive_inference.py +14 -6
interactive_inference.py CHANGED
@@ -65,7 +65,10 @@ def interactive_session():
65
  # We enable modulation to see the effect of the trained controller
66
  # The controller predicts modulation based on the input prompt
67
 
68
- inputs = model.tokenizer(user_input, return_tensors="pt").to(device)
 
 
 
69
 
70
  with torch.no_grad():
71
  # 1. Predict Modulation
@@ -78,17 +81,22 @@ def interactive_session():
78
  **inputs,
79
  max_new_tokens=128,
80
  do_sample=True,
81
- temperature=0.7,
82
- repetition_penalty=1.1,
 
83
  pad_token_id=model.tokenizer.eos_token_id
84
  )
85
 
86
  model.clear_modulation()
87
 
88
  response = model.tokenizer.decode(out_ids[0], skip_special_tokens=True)
89
- # Strip prompt if included (Gemma usually includes it)
90
- if response.startswith(user_input):
91
- response = response[len(user_input):].strip()
 
 
 
 
92
 
93
  print(f"MODEL: {response}")
94
  print(f" [Modulation Norm: {torch.norm(modulation).item():.2f}]")
 
65
  # We enable modulation to see the effect of the trained controller
66
  # The controller predicts modulation based on the input prompt
67
 
68
+ # Format the prompt to match training distribution
69
+ prompt = f"User: {user_input}\nModel: "
70
+
71
+ inputs = model.tokenizer(prompt, return_tensors="pt").to(device)
72
 
73
  with torch.no_grad():
74
  # 1. Predict Modulation
 
81
  **inputs,
82
  max_new_tokens=128,
83
  do_sample=True,
84
+ temperature=0.6,
85
+ top_p=0.9,
86
+ repetition_penalty=1.2,
87
  pad_token_id=model.tokenizer.eos_token_id
88
  )
89
 
90
  model.clear_modulation()
91
 
92
  response = model.tokenizer.decode(out_ids[0], skip_special_tokens=True)
93
+
94
+ # Clean up response (Remove the prompt part)
95
+ if response.startswith(prompt):
96
+ response = response[len(prompt):].strip()
97
+ elif "Model:" in response:
98
+ response = response.split("Model:")[-1].strip()
99
+
100
 
101
  print(f"MODEL: {response}")
102
  print(f" [Modulation Norm: {torch.norm(modulation).item():.2f}]")