import os, torch, gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig

# ==== Basic config (swap in a smaller model / fewer steps if needed) ====
BASE_MODEL = os.getenv("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
ADAPTER_DIR = os.getenv("ADAPTER_DIR", "lora_adapter")
TRAIN_PATH = os.getenv("TRAIN_PATH", "data/sft_train.jsonl")
VAL_PATH   = os.getenv("VAL_PATH", "data/sft_val.jsonl")
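# Expected JSONL format — an assumption based on TRL SFTTrainer defaults, adjust to
# your data: one JSON object per line, either with a plain "text" field, e.g.
#   {"text": "User: hello\nAssistant: hi there"}
# or a chat-style "messages" list that recent TRL versions render with the tokenizer's
# chat template, e.g.
#   {"messages": [{"role": "user", "content": "hello"}, {"role": "assistant", "content": "hi"}]}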

# ==== Lazy loading: placeholders only; the actual download happens on button click ====
_tokenizer = None
_base_model = None
_gen_model = None  # for inference (possibly with LoRA attached)

def load_base(load_in_4bit=None):
    global _tokenizer, _base_model
    if _tokenizer is None:
        _tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
        if _tokenizer.pad_token is None:  # some base tokenizers ship without a pad token
            _tokenizer.pad_token = _tokenizer.eos_token
    if _base_model is None:
        use_4bit = torch.cuda.is_available() if load_in_4bit is None else load_in_4bit
        if use_4bit:
            # NF4 quantization with fp16 compute — the usual QLoRA setup
            bnb = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
            )
            _base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, quantization_config=bnb, device_map="auto")
        else:
            _base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float32, device_map="cpu")
    return _tokenizer, _base_model


def train_qlora(max_steps=500, lora_r=16, lora_alpha=32, lora_dropout=0.05, per_device_bs=1, grad_accum=8):
    # Prepare data
    if not os.path.exists(TRAIN_PATH):
        return f"[Error] Train file not found: {TRAIN_PATH}"
    if not os.path.exists(VAL_PATH):
        return f"[Error] Val file not found: {VAL_PATH}"

    train_ds = load_dataset("json", data_files=TRAIN_PATH)["train"]
    val_ds   = load_dataset("json", data_files=VAL_PATH)["train"]

    tok, base = load_base(load_in_4bit=True)
    base = prepare_model_for_kbit_training(base)  # required prep before LoRA training on a 4-bit model

    # LoRA config (Gradio Number widgets return floats, so cast explicitly)
    peft_cfg = LoraConfig(
        r=int(lora_r), lora_alpha=int(lora_alpha), lora_dropout=float(lora_dropout),
        target_modules=["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj"],
        task_type="CAUSAL_LM"
    )
    model = get_peft_model(base, peft_cfg)
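    # Log the trainable-parameter fraction via peft's built-in helper; with QLoRA this
    # should come out to a small percentage of total weights.
    model.print_trainable_parameters()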

    # Training config (TRL). T4 GPUs do not support bf16, so fall back to fp16 there.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    sft_cfg = SFTConfig(
        output_dir=ADAPTER_DIR,
        max_steps=int(max_steps),
        per_device_train_batch_size=int(per_device_bs),
        gradient_accumulation_steps=int(grad_accum),
        learning_rate=2e-4,
        bf16=use_bf16,
        fp16=torch.cuda.is_available() and not use_bf16,
        logging_steps=20,
        save_steps=200,
        packing=False
    )
    trainer = SFTTrainer(
        model=model,
        tokenizer=tok,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        args=sft_cfg
    )
    trainer.train()
    trainer.save_model(ADAPTER_DIR)
    return f"✅ Trained LoRA saved to: {ADAPTER_DIR}"

def load_for_infer(adapter_dir=ADAPTER_DIR):
    global _gen_model
    tok, base = load_base(load_in_4bit=True)
    use_lora = bool(adapter_dir) and os.path.isdir(adapter_dir)
    _gen_model = PeftModel.from_pretrained(base, adapter_dir) if use_lora else base
    return "✅ Model ready (with LoRA)" if use_lora else "✅ Model ready (base only)"

def generate(prompt, max_new_tokens=200, adapter_dir=ADAPTER_DIR):
    if _gen_model is None:
        load_for_infer(adapter_dir)
    tok, _ = load_base(load_in_4bit=True)
    inputs = tok(prompt, return_tensors="pt").to(_gen_model.device)
    with torch.no_grad():
        out = _gen_model.generate(**inputs, max_new_tokens=int(max_new_tokens), do_sample=True, temperature=0.8)
    return tok.decode(out[0], skip_special_tokens=True)
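
def generate_chat(user_message, max_new_tokens=200, adapter_dir=ADAPTER_DIR):
    # Sketch of chat-formatted generation, assuming the base model is an *-Instruct
    # checkpoint whose tokenizer ships a chat template (Qwen2.5-Instruct does).
    # `generate` above feeds raw text; instruct models usually answer better when the
    # prompt is wrapped in their chat template first.
    tok, _ = load_base(load_in_4bit=True)
    prompt = tok.apply_chat_template(
        [{"role": "user", "content": user_message}],
        tokenize=False, add_generation_prompt=True,
    )
    return generate(prompt, max_new_tokens=max_new_tokens, adapter_dir=adapter_dir)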

# ==== Gradio UI ====
with gr.Blocks(title="WeChat Style QLoRA (Minimal)") as demo:
    gr.Markdown("## WeChat Style QLoRA — Minimal Demo  \n"
                "使用 QLoRA 在私有 JSONL 上做最小监督微调(SFT)并进行推理。  \n"
                "**建议流程**:先用 CPU Basic 启动验证 → 切到 ZeroGPU/T4 训练 30–60 分钟 → 保存 LoRA → 返回 CPU 测试推理。")

    with gr.Tab("Train (QLoRA)"):
        gr.Markdown("**请先把 `data/sft_train.jsonl` 和 `data/sft_val.jsonl` 上传到本 Space 的 `data/` 目录。**")
        ms = gr.Number(value=500, label="max_steps")
        r  = gr.Number(value=16, label="lora_r")
        a  = gr.Number(value=32, label="lora_alpha")
        d  = gr.Number(value=0.05, label="lora_dropout")
        bsz= gr.Number(value=1, label="per_device_train_batch_size")
        gas= gr.Number(value=8, label="gradient_accumulation_steps")
        train_btn = gr.Button("Start Training (GPU/ZeroGPU)")
        train_log = gr.Textbox(label="Training Log", interactive=False)
        train_btn.click(fn=train_qlora, inputs=[ms,r,a,d,bsz,gas], outputs=train_log)

    with gr.Tab("Inference"):
        gr.Markdown("默认会尝试加载 `lora_adapter/`。若还没训练,可直接用基础模型。")
        adapter = gr.Textbox(value=ADAPTER_DIR, label="LoRA adapter dir")
        load_btn = gr.Button("Load (with/without LoRA)")
        load_log = gr.Textbox(label="Status", interactive=False)
        load_btn.click(fn=load_for_infer, inputs=[adapter], outputs=load_log)

        prompt = gr.Textbox(lines=6, label="Prompt")
        gen_tokens = gr.Slider(32, 512, value=200, step=8, label="max_new_tokens")
        gen_btn = gr.Button("Generate")
        output = gr.Textbox(lines=12, label="Output")
        gen_btn.click(fn=generate, inputs=[prompt, gen_tokens, adapter], outputs=output)

    gr.Markdown("> 提示:训练时请先在 **Settings → Hardware** 切到 **ZeroGPU/T4**;完成后切回 **CPU Basic** 并停止 Space 以节省费用。")

if __name__ == "__main__":
    demo.queue().launch()