100XZX001 commited on
Commit
45d29c0
·
verified ·
1 Parent(s): 95b8e01

Update training.py

Browse files
Files changed (1) hide show
  1. training.py +4 -5
training.py CHANGED
@@ -1,8 +1,7 @@
1
  # training.py – Memory‑safe: Phi‑3‑mini + Expert Demos + Fast PPO (2 iterations)
2
  import os
3
- os.environ["TRITON_INTERPRET"] = "1" # force CPU interpretation, no ptxas
4
- os.environ["TRITON_CACHE_DIR"] = "/tmp/triton_cache"
5
- os.environ["TORCHINDUCTOR_CPP_WRAPPER"] = "0" # stay in Python # Issue #12: prevent OOM from parallel tokenization
6
 
7
  import torch._dynamo
8
  torch._dynamo.config.disable = True
@@ -79,7 +78,7 @@ def map_to_env(action: AgentAction):
79
  def load_model():
80
  model, tokenizer = FastLanguageModel.from_pretrained(
81
  model_name="unsloth/Phi-3-mini-4k-instruct-bnb-4bit",
82
- max_seq_length=2048,
83
  load_in_4bit=True,
84
  )
85
  model = FastLanguageModel.get_peft_model(
@@ -336,7 +335,7 @@ def generate_action_with_logprob(prompt, model, tokenizer, temperature=0.0, max_
336
  with torch.no_grad():
337
  outputs = model.generate(
338
  **inputs,
339
- max_new_tokens=128,
340
  do_sample=(temperature > 0),
341
  temperature=max(temperature, 0.01) if temperature > 0 else 1.0,
342
  min_new_tokens=1,
 
1
  # training.py – Memory‑safe: Phi‑3‑mini + Expert Demos + Fast PPO (2 iterations)
2
  import os
3
+ os.environ["TRITON_DISABLE"] = "1"
4
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"" # stay in Python # Issue #12: prevent OOM from parallel tokenization
 
5
 
6
  import torch._dynamo
7
  torch._dynamo.config.disable = True
 
78
  def load_model():
79
  model, tokenizer = FastLanguageModel.from_pretrained(
80
  model_name="unsloth/Phi-3-mini-4k-instruct-bnb-4bit",
81
+ max_seq_length=768,
82
  load_in_4bit=True,
83
  )
84
  model = FastLanguageModel.get_peft_model(
 
335
  with torch.no_grad():
336
  outputs = model.generate(
337
  **inputs,
338
+ max_new_tokens=64,
339
  do_sample=(temperature > 0),
340
  temperature=max(temperature, 0.01) if temperature > 0 else 1.0,
341
  min_new_tokens=1,