md896 committed on
Commit
e5262a1
·
1 Parent(s): d21de11

Simplify HF training stack: remove unsloth/vllm path, use plain transformers AutoModel + single OpenEnv reward.

Browse files
Files changed (1) hide show
  1. ultimate_sota_training.py +12 -89
ultimate_sota_training.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- Unsloth + OpenEnv GRPO training (production-oriented).
3
 
4
  Produces real training artifacts (trainer log_history, metrics JSON, reward plots) and
5
  optional Hub push of LoRA weights. Every execution reward calls your live Space (or
@@ -30,7 +30,6 @@ Key stability choices:
30
 
31
  from __future__ import annotations
32
 
33
- import importlib.metadata as importlib_metadata
34
  import json
35
  import os
36
  import random
@@ -96,14 +95,6 @@ def bootstrap_deps() -> None:
96
  if os.environ.get("WANDB_API_KEY"):
97
  _pip(["install", "--break-system-packages", "wandb"], check=False)
98
 
99
- _pip(
100
- [
101
- "install",
102
- "--break-system-packages",
103
- "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git",
104
- ]
105
- )
106
-
107
  _pip(
108
  [
109
  "install",
@@ -119,7 +110,7 @@ def bootstrap_deps() -> None:
119
  _pip(["uninstall", "-y", "torchao"], check=False)
120
  _pip(["uninstall", "--break-system-packages", "-y", "torchvision", "torchaudio"], check=False)
121
 
122
- # Do not import transformers/trl here. Unsloth must be imported first later.
123
 
124
 
125
  bootstrap_deps()
@@ -127,74 +118,14 @@ bootstrap_deps()
127
  import httpx
128
  import torch
129
  from datasets import Dataset
130
-
131
- # --- CRITICAL FIXES FOR HF JOBS ---
132
- # 0. Unsloth checks importlib.metadata.version("vllm") at import time.
133
- # In text-only GRPO runs we don't install vllm, so return a dummy version instead
134
- # of crashing with PackageNotFoundError.
135
- _real_pkg_version = importlib_metadata.version
136
-
137
-
138
- def _safe_pkg_version(dist_name: str) -> str:
139
- if dist_name == "vllm":
140
- return "0.0.0"
141
- return _real_pkg_version(dist_name)
142
-
143
-
144
- importlib_metadata.version = _safe_pkg_version
145
-
146
- # 1. Mock vllm: TRL's GRPOTrainer (v0.18+) has a buggy import path that hard-fails if vllm is missing.
147
- # We must provide a mock that satisfies both 'import' and 'importlib.util.find_spec'.
148
- import sys
149
- import types
150
- import importlib.machinery
151
-
152
- def mock_vllm_hierarchy():
153
- pkg_names = [
154
- "vllm",
155
- "vllm.distributed",
156
- "vllm.distributed.device_communicators",
157
- "vllm.model_executor",
158
- "vllm.model_executor.parallel_utils",
159
- ]
160
- leaf_names = [
161
- "vllm.distributed.device_communicators.pynccl",
162
- ]
163
-
164
- # Create proper package-like modules with submodule_search_locations so
165
- # unsloth's import fixes that inspect package paths don't crash.
166
- for m_name in pkg_names:
167
- mod = types.ModuleType(m_name)
168
- mod.__package__ = m_name
169
- mod.__path__ = [f"/tmp/mock_{m_name.replace('.', '_')}"]
170
- spec = importlib.machinery.ModuleSpec(m_name, loader=None, is_package=True)
171
- spec.submodule_search_locations = mod.__path__
172
- mod.__spec__ = spec
173
- sys.modules[m_name] = mod
174
-
175
- for m_name in leaf_names:
176
- mod = types.ModuleType(m_name)
177
- mod.__package__ = m_name.rsplit(".", 1)[0]
178
- mod.__spec__ = importlib.machinery.ModuleSpec(m_name, loader=None, is_package=False)
179
- sys.modules[m_name] = mod
180
-
181
- mock_vllm_hierarchy()
182
-
183
- # Import Unsloth before transformers / trl for its patching path.
184
- from unsloth import FastLanguageModel
185
-
186
- # 2. Mock llm_blender: Fix for TRANSFORMERS_CACHE removal in transformers 4.40+.
187
- import transformers.utils.hub
188
- if not hasattr(transformers.utils.hub, "TRANSFORMERS_CACHE"):
189
- transformers.utils.hub.TRANSFORMERS_CACHE = "/tmp"
190
-
191
  from trl import GRPOConfig, GRPOTrainer
192
 
193
  # --- 1. CONFIGURATION (env-first; defaults match openenv.yaml) ---
194
  _DEFAULT_OPENENV_BASE = "https://md896-sql-debug-env.hf.space"
195
  BYPASS_HEADERS: Dict[str, str] = {}
196
 
197
- MODEL_NAME = os.environ.get("TRAIN_MODEL_NAME", "unsloth/Qwen2.5-Coder-7B-Instruct")
198
 
199
 
200
  def get_bridge_url() -> str:
@@ -410,33 +341,25 @@ def _resolve_report_to() -> str:
410
  return raw
411
 
412
 
413
- # --- 4. Unsloth GRPO training loop (live OpenEnv rewards) ---
414
  def run_sota_train():
415
  max_steps = int(os.environ.get("TRAIN_MAX_STEPS", "200"))
416
  out_dir = os.environ.get("OUTPUT_DIR", "./sota_results")
417
 
418
- print(f"Starting Unsloth GRPO on {MODEL_NAME}...")
419
  print(
420
  f"OpenEnv={get_bridge_url()} | max_steps={max_steps} | "
421
  f"rows_per_task={os.environ.get('ROWS_PER_TASK', '48')} | "
422
  f"report_to={_resolve_report_to()}"
423
  )
424
 
425
- max_seq = int(os.environ.get("MAX_SEQ_LENGTH", "1024"))
426
- model, tokenizer = FastLanguageModel.from_pretrained(
427
- model_name=MODEL_NAME,
428
- max_seq_length=max_seq,
429
- load_in_4bit=True,
430
- )
431
-
432
  tokenizer.pad_token = tokenizer.eos_token
433
-
434
- # APPLY UNSLOTH LORA ADAPTERS
435
- model = FastLanguageModel.get_peft_model(
436
- model,
437
- r=16,
438
- lora_alpha=16,
439
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
440
  )
441
 
442
  train_dataset = make_real_dataset()
 
1
  """
2
+ OpenEnv GRPO training (production-oriented, simple stack).
3
 
4
  Produces real training artifacts (trainer log_history, metrics JSON, reward plots) and
5
  optional Hub push of LoRA weights. Every execution reward calls your live Space (or
 
30
 
31
  from __future__ import annotations
32
 
 
33
  import json
34
  import os
35
  import random
 
95
  if os.environ.get("WANDB_API_KEY"):
96
  _pip(["install", "--break-system-packages", "wandb"], check=False)
97
 
 
 
 
 
 
 
 
 
98
  _pip(
99
  [
100
  "install",
 
110
  _pip(["uninstall", "-y", "torchao"], check=False)
111
  _pip(["uninstall", "--break-system-packages", "-y", "torchvision", "torchaudio"], check=False)
112
 
113
+ # Keep bootstrap import-free; training imports happen below.
114
 
115
 
116
  bootstrap_deps()
 
118
  import httpx
119
  import torch
120
  from datasets import Dataset
121
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  from trl import GRPOConfig, GRPOTrainer
123
 
124
  # --- 1. CONFIGURATION (env-first; defaults match openenv.yaml) ---
125
  _DEFAULT_OPENENV_BASE = "https://md896-sql-debug-env.hf.space"
126
  BYPASS_HEADERS: Dict[str, str] = {}
127
 
128
+ MODEL_NAME = os.environ.get("TRAIN_MODEL_NAME", "Qwen/Qwen2.5-Coder-0.5B-Instruct")
129
 
130
 
131
  def get_bridge_url() -> str:
 
341
  return raw
342
 
343
 
344
+ # --- 4. Simple GRPO training loop (live OpenEnv rewards) ---
345
  def run_sota_train():
346
  max_steps = int(os.environ.get("TRAIN_MAX_STEPS", "200"))
347
  out_dir = os.environ.get("OUTPUT_DIR", "./sota_results")
348
 
349
+ print(f"Starting GRPO on {MODEL_NAME}...")
350
  print(
351
  f"OpenEnv={get_bridge_url()} | max_steps={max_steps} | "
352
  f"rows_per_task={os.environ.get('ROWS_PER_TASK', '48')} | "
353
  f"report_to={_resolve_report_to()}"
354
  )
355
 
356
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
 
 
 
 
 
357
  tokenizer.pad_token = tokenizer.eos_token
358
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
359
+ model = AutoModelForCausalLM.from_pretrained(
360
+ MODEL_NAME,
361
+ torch_dtype=torch_dtype,
362
+ device_map="auto",
 
 
363
  )
364
 
365
  train_dataset = make_real_dataset()