Upload interactive_inference.py
Browse files- interactive_inference.py +14 -6
interactive_inference.py
CHANGED
|
@@ -65,7 +65,10 @@ def interactive_session():
|
|
| 65 |
# We enable modulation to see the effect of the trained controller
|
| 66 |
# The controller predicts modulation based on the input prompt
|
| 67 |
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
with torch.no_grad():
|
| 71 |
# 1. Predict Modulation
|
|
@@ -78,17 +81,22 @@ def interactive_session():
|
|
| 78 |
**inputs,
|
| 79 |
max_new_tokens=128,
|
| 80 |
do_sample=True,
|
| 81 |
-
temperature=0.
|
| 82 |
-
|
|
|
|
| 83 |
pad_token_id=model.tokenizer.eos_token_id
|
| 84 |
)
|
| 85 |
|
| 86 |
model.clear_modulation()
|
| 87 |
|
| 88 |
response = model.tokenizer.decode(out_ids[0], skip_special_tokens=True)
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
print(f"MODEL: {response}")
|
| 94 |
print(f" [Modulation Norm: {torch.norm(modulation).item():.2f}]")
|
|
|
|
| 65 |
# We enable modulation to see the effect of the trained controller
|
| 66 |
# The controller predicts modulation based on the input prompt
|
| 67 |
|
| 68 |
+
# Format the prompt to match training distribution
|
| 69 |
+
prompt = f"User: {user_input}\nModel: "
|
| 70 |
+
|
| 71 |
+
inputs = model.tokenizer(prompt, return_tensors="pt").to(device)
|
| 72 |
|
| 73 |
with torch.no_grad():
|
| 74 |
# 1. Predict Modulation
|
|
|
|
| 81 |
**inputs,
|
| 82 |
max_new_tokens=128,
|
| 83 |
do_sample=True,
|
| 84 |
+
temperature=0.6,
|
| 85 |
+
top_p=0.9,
|
| 86 |
+
repetition_penalty=1.2,
|
| 87 |
pad_token_id=model.tokenizer.eos_token_id
|
| 88 |
)
|
| 89 |
|
| 90 |
model.clear_modulation()
|
| 91 |
|
| 92 |
response = model.tokenizer.decode(out_ids[0], skip_special_tokens=True)
|
| 93 |
+
|
| 94 |
+
# Clean up response (Remove the prompt part)
|
| 95 |
+
if response.startswith(prompt):
|
| 96 |
+
response = response[len(prompt):].strip()
|
| 97 |
+
elif "Model:" in response:
|
| 98 |
+
response = response.split("Model:")[-1].strip()
|
| 99 |
+
|
| 100 |
|
| 101 |
print(f"MODEL: {response}")
|
| 102 |
print(f" [Modulation Norm: {torch.norm(modulation).item():.2f}]")
|