import os

import gradio as gr
import torch
from datasets import load_dataset
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig

# Configuration, overridable via environment variables.
BASE_MODEL = os.getenv("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
ADAPTER_DIR = os.getenv("ADAPTER_DIR", "lora_adapter")
TRAIN_PATH = os.getenv("TRAIN_PATH", "data/sft_train.jsonl")
VAL_PATH = os.getenv("VAL_PATH", "data/sft_val.jsonl")
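
# Data format (a sketch, not prescribed by this script): SFTTrainer reads a "text"
# field by default, so each JSONL line might look like
#   {"text": "User: What's for dinner?\nAssistant: How about hotpot?"}
# If your rows use a chat-style {"messages": [...]} schema instead, recent trl
# versions apply the tokenizer's chat template automatically.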

# Lazily-initialized globals so the tokenizer and model are loaded once per process.
_tokenizer = None
_base_model = None
_gen_model = None


def load_base(load_in_4bit=None):
    """Load tokenizer and base model once: 4-bit NF4 on GPU, fp32 on CPU by default."""
    global _tokenizer, _base_model
    if _tokenizer is None:
        _tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
        if _tokenizer.pad_token is None:
            _tokenizer.pad_token = _tokenizer.eos_token
    if _base_model is None:
        use_4bit = torch.cuda.is_available() if load_in_4bit is None else load_in_4bit
        if use_4bit:
            bnb = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
                if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
                else torch.float16,
            )
            _base_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL, quantization_config=bnb, device_map="auto"
            )
        else:
            _base_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL, torch_dtype=torch.float32, device_map="cpu"
            )
    return _tokenizer, _base_model


def train_qlora(max_steps=500, lora_r=16, lora_alpha=32, lora_dropout=0.05, per_device_bs=1, grad_accum=8):
    if not os.path.exists(TRAIN_PATH):
        return f"[Error] Train file not found: {TRAIN_PATH}"
    if not os.path.exists(VAL_PATH):
        return f"[Error] Val file not found: {VAL_PATH}"

    train_ds = load_dataset("json", data_files=TRAIN_PATH)["train"]
    val_ds = load_dataset("json", data_files=VAL_PATH)["train"]

    tok, base = load_base(load_in_4bit=True)
    # Required for stable QLoRA training: enables gradient checkpointing and
    # input gradients on the quantized base model.
    base = prepare_model_for_kbit_training(base)

    # Gradio Number inputs arrive as floats, so cast before handing them to PEFT.
    peft_cfg = LoraConfig(
        r=int(lora_r),
        lora_alpha=int(lora_alpha),
        lora_dropout=float(lora_dropout),
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj"],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(base, peft_cfg)
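    # Sanity check: print_trainable_parameters() is a standard PEFT method; the
    # trainable share should be a small fraction of the base model's parameters.
    model.print_trainable_parameters()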

    sft_cfg = SFTConfig(
        output_dir=ADAPTER_DIR,
        max_steps=int(max_steps),
        per_device_train_batch_size=int(per_device_bs),
        gradient_accumulation_steps=int(grad_accum),
        learning_rate=2e-4,
        # T4 GPUs do not support bf16; fall back to fp16 there.
        bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
        fp16=torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),
        logging_steps=20,
        save_steps=200,
        packing=False,
    )
    trainer = SFTTrainer(
        model=model,
        tokenizer=tok,  # renamed to processing_class= in newer trl releases
        train_dataset=train_ds,
        eval_dataset=val_ds,
        args=sft_cfg,
    )
    trainer.train()
    trainer.save_model(ADAPTER_DIR)
    return f"✅ Trained LoRA saved to: {ADAPTER_DIR}"


def load_for_infer(adapter_dir=ADAPTER_DIR):
    global _gen_model
    # Let load_base() auto-detect hardware: 4-bit loading needs a GPU, and the
    # suggested workflow runs inference back on CPU Basic.
    tok, base = load_base()
    if adapter_dir and os.path.isdir(adapter_dir):
        _gen_model = PeftModel.from_pretrained(base, adapter_dir)
        return "✅ Model ready (with LoRA)"
    _gen_model = base
    return "✅ Model ready (base only)"


def generate(prompt, max_new_tokens=200, adapter_dir=ADAPTER_DIR):
    if _gen_model is None:
        load_for_infer(adapter_dir)
    tok, _ = load_base()
    inputs = tok(prompt, return_tensors="pt").to(_gen_model.device)
    with torch.no_grad():
        out = _gen_model.generate(
            **inputs, max_new_tokens=int(max_new_tokens), do_sample=True, temperature=0.8
        )
    return tok.decode(out[0], skip_special_tokens=True)
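

# Chat-template variant, a sketch (the helper name is ours; the UI below does not
# call it). For instruct-tuned bases like Qwen2.5-Instruct, wrapping the prompt with
# apply_chat_template usually yields better replies than feeding raw text.
def generate_chat(prompt, max_new_tokens=200):
    if _gen_model is None:
        load_for_infer()
    tok, _ = load_base()
    messages = [{"role": "user", "content": prompt}]
    text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tok(text, return_tensors="pt").to(_gen_model.device)
    with torch.no_grad():
        out = _gen_model.generate(**inputs, max_new_tokens=int(max_new_tokens),
                                  do_sample=True, temperature=0.8)
    # Decode only the newly generated tokens, skipping the prompt echo.
    return tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)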


with gr.Blocks(title="WeChat Style QLoRA (Minimal)") as demo:
    gr.Markdown(
        "## WeChat Style QLoRA — Minimal Demo  \n"
        "Minimal supervised fine-tuning (SFT) with QLoRA on private JSONL data, plus inference.  \n"
        "**Suggested workflow**: start on CPU Basic to verify the app → switch to ZeroGPU/T4 and train for 30–60 minutes → save the LoRA → switch back to CPU and test inference."
    )

    with gr.Tab("Train (QLoRA)"):
        gr.Markdown("**First upload `data/sft_train.jsonl` and `data/sft_val.jsonl` to this Space's `data/` directory.**")
        ms = gr.Number(value=500, label="max_steps")
        r = gr.Number(value=16, label="lora_r")
        a = gr.Number(value=32, label="lora_alpha")
        d = gr.Number(value=0.05, label="lora_dropout")
        bsz = gr.Number(value=1, label="per_device_train_batch_size")
        gas = gr.Number(value=8, label="gradient_accumulation_steps")
        train_btn = gr.Button("Start Training (GPU/ZeroGPU)")
        train_log = gr.Textbox(label="Training Log", interactive=False)
        train_btn.click(fn=train_qlora, inputs=[ms, r, a, d, bsz, gas], outputs=train_log)

    with gr.Tab("Inference"):
        gr.Markdown("Loads `lora_adapter/` by default; if you have not trained yet, the base model is used as-is.")
        adapter = gr.Textbox(value=ADAPTER_DIR, label="LoRA adapter dir")
        load_btn = gr.Button("Load (with/without LoRA)")
        load_log = gr.Textbox(label="Status", interactive=False)
        load_btn.click(fn=load_for_infer, inputs=[adapter], outputs=load_log)

        prompt = gr.Textbox(lines=6, label="Prompt")
        gen_tokens = gr.Slider(32, 512, value=200, step=8, label="max_new_tokens")
        gen_btn = gr.Button("Generate")
        output = gr.Textbox(lines=12, label="Output")
        gen_btn.click(fn=generate, inputs=[prompt, gen_tokens, adapter], outputs=output)

    gr.Markdown("> Tip: before training, switch to **ZeroGPU/T4** under **Settings → Hardware**; when finished, switch back to **CPU Basic** and stop the Space to save cost.")


if __name__ == "__main__":
    demo.queue().launch()