Spaces:
Sleeping
Sleeping
Update training.py
Browse files- training.py +4 -5
training.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
| 1 |
# training.py – Memory‑safe: Phi‑3‑mini + Expert Demos + Fast PPO (2 iterations)
|
| 2 |
import os
|
| 3 |
-
os.environ["
|
| 4 |
-
os.environ["
|
| 5 |
-
os.environ["TORCHINDUCTOR_CPP_WRAPPER"] = "0" # stay in Python # Issue #12: prevent OOM from parallel tokenization
|
| 6 |
|
| 7 |
import torch._dynamo
|
| 8 |
torch._dynamo.config.disable = True
|
|
@@ -79,7 +78,7 @@ def map_to_env(action: AgentAction):
|
|
| 79 |
def load_model():
|
| 80 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 81 |
model_name="unsloth/Phi-3-mini-4k-instruct-bnb-4bit",
|
| 82 |
-
max_seq_length=
|
| 83 |
load_in_4bit=True,
|
| 84 |
)
|
| 85 |
model = FastLanguageModel.get_peft_model(
|
|
@@ -336,7 +335,7 @@ def generate_action_with_logprob(prompt, model, tokenizer, temperature=0.0, max_
|
|
| 336 |
with torch.no_grad():
|
| 337 |
outputs = model.generate(
|
| 338 |
**inputs,
|
| 339 |
-
max_new_tokens=
|
| 340 |
do_sample=(temperature > 0),
|
| 341 |
temperature=max(temperature, 0.01) if temperature > 0 else 1.0,
|
| 342 |
min_new_tokens=1,
|
|
|
|
| 1 |
# training.py – Memory‑safe: Phi‑3‑mini + Expert Demos + Fast PPO (2 iterations)
|
| 2 |
import os
|
| 3 |
+
os.environ["TRITON_DISABLE"] = "1"
|
| 4 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # stay in Python # Issue #12: prevent OOM from parallel tokenization
|
|
|
|
| 5 |
|
| 6 |
import torch._dynamo
|
| 7 |
torch._dynamo.config.disable = True
|
|
|
|
| 78 |
def load_model():
|
| 79 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 80 |
model_name="unsloth/Phi-3-mini-4k-instruct-bnb-4bit",
|
| 81 |
+
max_seq_length=768,
|
| 82 |
load_in_4bit=True,
|
| 83 |
)
|
| 84 |
model = FastLanguageModel.get_peft_model(
|
|
|
|
| 335 |
with torch.no_grad():
|
| 336 |
outputs = model.generate(
|
| 337 |
**inputs,
|
| 338 |
+
max_new_tokens=64,
|
| 339 |
do_sample=(temperature > 0),
|
| 340 |
temperature=max(temperature, 0.01) if temperature > 0 else 1.0,
|
| 341 |
min_new_tokens=1,
|