td-builder commited on
Commit
e123506
·
verified ·
1 Parent(s): d30aa8a

Upload 137 files

Browse files
Files changed (1) hide show
  1. hugging/td_fuse/heal.py +17 -1
hugging/td_fuse/heal.py CHANGED
@@ -29,6 +29,22 @@ from datasets import load_dataset
29
  from .config import MergeConfig
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def check_unsloth_available() -> bool:
33
  """Check if Unsloth is installed and working."""
34
  try:
@@ -235,7 +251,7 @@ def apply_qlora_standard(
235
  )
236
 
237
  tokenizer = AutoTokenizer.from_pretrained(model_path)
238
- model = AutoModelForCausalLM.from_pretrained(
239
  model_path,
240
  quantization_config=bnb_config,
241
  device_map="auto",
 
29
  from .config import MergeConfig
30
 
31
 
32
+ def _load_model_smart(checkpoint, **kwargs):
33
+ """Load model — auto-detects Qwen3-VL and uses the correct class."""
34
+ from transformers import AutoConfig
35
+ try:
36
+ config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
37
+ model_type = getattr(config, 'model_type', '')
38
+ config_class = type(config).__name__.lower()
39
+ if 'qwen3_vl' in model_type or 'qwen3vl' in config_class:
40
+ from transformers import Qwen3VLForConditionalGeneration
41
+ print(f'[heal] Loading as Qwen3-VL model: {checkpoint}')
42
+ return Qwen3VLForConditionalGeneration.from_pretrained(checkpoint, **kwargs)
43
+ except Exception as e:
44
+ print(f'[heal] Auto-detect failed ({e}), using AutoModelForCausalLM')
45
+ return AutoModelForCausalLM.from_pretrained(checkpoint, **kwargs)
46
+
47
+
48
  def check_unsloth_available() -> bool:
49
  """Check if Unsloth is installed and working."""
50
  try:
 
251
  )
252
 
253
  tokenizer = AutoTokenizer.from_pretrained(model_path)
254
+ model = _load_model_smart(
255
  model_path,
256
  quantization_config=bnb_config,
257
  device_map="auto",