Upload 137 files
Browse files- hugging/td_fuse/heal.py +17 -1
hugging/td_fuse/heal.py
CHANGED
|
@@ -29,6 +29,22 @@ from datasets import load_dataset
|
|
| 29 |
from .config import MergeConfig
|
| 30 |
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
def check_unsloth_available() -> bool:
|
| 33 |
"""Check if Unsloth is installed and working."""
|
| 34 |
try:
|
|
@@ -235,7 +251,7 @@ def apply_qlora_standard(
|
|
| 235 |
)
|
| 236 |
|
| 237 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 238 |
-
model =
|
| 239 |
model_path,
|
| 240 |
quantization_config=bnb_config,
|
| 241 |
device_map="auto",
|
|
|
|
| 29 |
from .config import MergeConfig
|
| 30 |
|
| 31 |
|
| 32 |
+
def _load_model_smart(checkpoint, **kwargs):
|
| 33 |
+
"""Load model — auto-detects Qwen3-VL and uses the correct class."""
|
| 34 |
+
from transformers import AutoConfig
|
| 35 |
+
try:
|
| 36 |
+
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
|
| 37 |
+
model_type = getattr(config, 'model_type', '')
|
| 38 |
+
config_class = type(config).__name__.lower()
|
| 39 |
+
if 'qwen3_vl' in model_type or 'qwen3vl' in config_class:
|
| 40 |
+
from transformers import Qwen3VLForConditionalGeneration
|
| 41 |
+
print(f'[heal] Loading as Qwen3-VL model: {checkpoint}')
|
| 42 |
+
return Qwen3VLForConditionalGeneration.from_pretrained(checkpoint, **kwargs)
|
| 43 |
+
except Exception as e:
|
| 44 |
+
print(f'[heal] Auto-detect failed ({e}), using AutoModelForCausalLM')
|
| 45 |
+
return AutoModelForCausalLM.from_pretrained(checkpoint, **kwargs)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
def check_unsloth_available() -> bool:
|
| 49 |
"""Check if Unsloth is installed and working."""
|
| 50 |
try:
|
|
|
|
| 251 |
)
|
| 252 |
|
| 253 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 254 |
+
model = _load_model_smart(
|
| 255 |
model_path,
|
| 256 |
quantization_config=bnb_config,
|
| 257 |
device_map="auto",
|