# Commit 027ebe0 (xianyu564): optimize load_base — adjust the default of the
# load_in_4bit parameter and improve the model-loading logic.
import os, json, torch, gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, PeftModel, get_peft_model
from trl import SFTTrainer, SFTConfig
# ==== Basic configuration (model/step count can be downsized; all overridable via env) ====
BASE_MODEL = os.getenv("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")  # HF Hub model id
ADAPTER_DIR = os.getenv("ADAPTER_DIR", "lora_adapter")  # where the trained LoRA adapter is saved/loaded
TRAIN_PATH = os.getenv("TRAIN_PATH", "data/sft_train.jsonl")  # SFT training set (JSONL)
VAL_PATH = os.getenv("VAL_PATH", "data/sft_val.jsonl")  # SFT validation set (JSONL)
# ==== Lazy loading: placeholders here; the real download happens on button click ====
_tokenizer = None
_base_model = None
_gen_model = None # inference model (possibly with a LoRA adapter attached)
def load_base(load_in_4bit=None):
    """Lazily load and cache the shared tokenizer and base model.

    Both objects are cached in module globals, so repeated calls are cheap.

    Args:
        load_in_4bit: Force NF4 4-bit quantization on/off. ``None`` (default)
            enables 4-bit exactly when CUDA is available. An explicit ``True``
            is still ignored on CPU-only hosts, because bitsandbytes 4-bit
            loading requires a CUDA device and would otherwise crash.

    Returns:
        ``(tokenizer, base_model)`` tuple.
    """
    global _tokenizer, _base_model
    if _tokenizer is None:
        _tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
        # Some chat models ship without a pad token; SFT batching and
        # generation both need one, so fall back to EOS.
        if _tokenizer.pad_token is None:
            _tokenizer.pad_token = _tokenizer.eos_token
    if _base_model is None:
        has_cuda = torch.cuda.is_available()
        # Never attempt 4-bit on CPU, even when the caller asks for it.
        use_4bit = has_cuda if load_in_4bit is None else (bool(load_in_4bit) and has_cuda)
        if use_4bit:
            bnb = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                # fp32 is the (slow) default compute dtype; bf16 is the usual
                # choice for QLoRA on modern GPUs.
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            _base_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL, quantization_config=bnb, device_map="auto"
            )
        else:
            _base_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL, torch_dtype=torch.float32, device_map="cpu"
            )
    return _tokenizer, _base_model
def train_qlora(max_steps=500, lora_r=16, lora_alpha=32, lora_dropout=0.05, per_device_bs=1, grad_accum=8):
    """Run a minimal QLoRA supervised fine-tune (SFT) on the private JSONL data.

    All numeric parameters may arrive as floats — Gradio ``Number`` widgets
    emit floats — so every value is coerced before being handed to the
    libraries (a float ``per_device_train_batch_size`` breaks the DataLoader).

    Args:
        max_steps: Total optimizer steps to train.
        lora_r / lora_alpha / lora_dropout: Standard LoRA hyperparameters.
        per_device_bs: Per-device train batch size.
        grad_accum: Gradient-accumulation steps.

    Returns:
        A human-readable status string (error message or success notice) —
        never raises for the expected "missing data file" case.
    """
    # Validate data files up front so the UI shows a clear error string.
    if not os.path.exists(TRAIN_PATH):
        return f"[Error] Train file not found: {TRAIN_PATH}"
    if not os.path.exists(VAL_PATH):
        return f"[Error] Val file not found: {VAL_PATH}"
    train_ds = load_dataset("json", data_files=TRAIN_PATH)["train"]
    val_ds = load_dataset("json", data_files=VAL_PATH)["train"]
    tok, base = load_base(load_in_4bit=True)
    # LoRA configuration (ints coerced: LoraConfig expects integer rank/alpha).
    peft_cfg = LoraConfig(
        r=int(lora_r),
        lora_alpha=int(lora_alpha),
        lora_dropout=float(lora_dropout),
        target_modules=["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj"],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(base, peft_cfg)
    # Training configuration (TRL).
    sft_cfg = SFTConfig(
        output_dir=ADAPTER_DIR,
        max_steps=int(max_steps),
        per_device_train_batch_size=int(per_device_bs),
        gradient_accumulation_steps=int(grad_accum),
        learning_rate=2e-4,
        bf16=torch.cuda.is_available(),  # bf16 only makes sense on GPU
        logging_steps=20,
        save_steps=200,
        packing=False,
    )
    trainer = SFTTrainer(
        model=model,
        tokenizer=tok,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        args=sft_cfg,
    )
    trainer.train()
    trainer.save_model(ADAPTER_DIR)
    return f"✅ Trained LoRA saved to: {ADAPTER_DIR}"
def load_for_infer(adapter_dir=ADAPTER_DIR):
    """Prepare the inference model, attaching a LoRA adapter when one exists.

    Args:
        adapter_dir: Directory holding a trained PEFT adapter; when empty or
            missing, the plain base model is used instead.

    Returns:
        A status string describing which model variant is ready.
    """
    global _gen_model
    # Let load_base pick the quantization mode from the hardware instead of
    # forcing 4-bit, which would crash on CPU-only hosts (e.g. CPU Basic).
    tok, base = load_base()
    has_adapter = bool(adapter_dir) and os.path.isdir(adapter_dir)
    if has_adapter:
        _gen_model = PeftModel.from_pretrained(base, adapter_dir)
        return "✅ Model ready (with LoRA)"
    _gen_model = base
    return "✅ Model ready (base only)"
def generate(prompt, max_new_tokens=200, adapter_dir=ADAPTER_DIR):
    """Generate a sampled completion for *prompt* using the loaded model.

    Lazily loads the inference model (with adapter, if present) on first use.

    Args:
        prompt: Input text to complete.
        max_new_tokens: Generation budget (coerced to int; Gradio sliders
            may deliver floats).
        adapter_dir: Adapter directory forwarded to :func:`load_for_infer`
            on the lazy-load path.

    Returns:
        The decoded generation, including the prompt, with special tokens
        stripped.
    """
    if _gen_model is None:
        load_for_infer(adapter_dir)
    # Tokenizer is cached by load_base; don't force a quantization mode here
    # (forcing 4-bit crashes on CPU-only hosts).
    tok, _ = load_base()
    inputs = tok(prompt, return_tensors="pt").to(_gen_model.device)
    with torch.no_grad():
        out = _gen_model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=0.8,
            # Explicit pad id avoids the open-ended-generation warning/failure
            # for models that ship without a pad token.
            pad_token_id=tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id,
        )
    return tok.decode(out[0], skip_special_tokens=True)
# ==== Gradio UI ====
with gr.Blocks(title="WeChat Style QLoRA (Minimal)") as demo:
    gr.Markdown(
        "## WeChat Style QLoRA — Minimal Demo \n"
        "使用 QLoRA 在私有 JSONL 上做最小监督微调(SFT)并进行推理。 \n"
        "**建议流程**:先用 CPU Basic 启动验证 → 切到 ZeroGPU/T4 训练 30–60 分钟 → 保存 LoRA → 返回 CPU 测试推理。"
    )

    # --- Training tab: hyperparameter inputs wired straight into train_qlora ---
    with gr.Tab("Train (QLoRA)"):
        gr.Markdown("**请先把 `data/sft_train.jsonl` 和 `data/sft_val.jsonl` 上传到本 Space 的 `data/` 目录。**")
        steps_in = gr.Number(value=500, label="max_steps")
        rank_in = gr.Number(value=16, label="lora_r")
        alpha_in = gr.Number(value=32, label="lora_alpha")
        dropout_in = gr.Number(value=0.05, label="lora_dropout")
        batch_in = gr.Number(value=1, label="per_device_train_batch_size")
        accum_in = gr.Number(value=8, label="gradient_accumulation_steps")
        train_btn = gr.Button("Start Training (GPU/ZeroGPU)")
        train_log = gr.Textbox(label="Training Log", interactive=False)
        train_btn.click(
            fn=train_qlora,
            inputs=[steps_in, rank_in, alpha_in, dropout_in, batch_in, accum_in],
            outputs=train_log,
        )

    # --- Inference tab: adapter loading plus free-form generation ---
    with gr.Tab("Inference"):
        gr.Markdown("默认会尝试加载 `lora_adapter/`。若还没训练,可直接用基础模型。")
        adapter_box = gr.Textbox(value=ADAPTER_DIR, label="LoRA adapter dir")
        load_btn = gr.Button("Load (with/without LoRA)")
        load_log = gr.Textbox(label="Status", interactive=False)
        load_btn.click(fn=load_for_infer, inputs=[adapter_box], outputs=load_log)
        prompt_box = gr.Textbox(lines=6, label="Prompt")
        tokens_slider = gr.Slider(32, 512, value=200, step=8, label="max_new_tokens")
        gen_btn = gr.Button("Generate")
        output_box = gr.Textbox(lines=12, label="Output")
        gen_btn.click(fn=generate, inputs=[prompt_box, tokens_slider, adapter_box], outputs=output_box)

    gr.Markdown("> 提示:训练时请先在 **Settings → Hardware** 切到 **ZeroGPU/T4**;完成后切回 **CPU Basic** 并停止 Space 以节省费用。")

if __name__ == "__main__":
    demo.queue().launch()