# SSS-Distillation / train.py
# (Hugging Face Space file-viewer header: uploaded by Snow2222,
#  last commit "Update train.py", 9c056a5 verified)
import os
import json
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from datasets import Dataset
# Point the Hugging Face cache at a writable tmp dir (on Spaces only /tmp
# is guaranteed writable).
os.environ['HF_HOME'] = '/tmp/huggingface_cache'
# Read the Hugging Face access token; the HF_TOKEN env var must be set.
hf_token = os.getenv('HF_TOKEN')
if hf_token:
    # Persist the token so later from_pretrained / push_to_hub calls find it.
    from huggingface_hub import HfFolder
    HfFolder.save_token(hf_token)
else:
    raise ValueError("Hugging Face token 未设置")
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ✅ **换成兼容的学生模型**
teacher_model_name = "Qwen/Qwen1.5-7B-Chat"
student_model_name = "Qwen/Qwen1.5-1.8B-Chat" # ✅ **换成 Qwen1.5 1.8B 版本**
# 加载教师模型
teacher = AutoModelForCausalLM.from_pretrained(
teacher_model_name,
trust_remote_code=True,
token=hf_token
).to(device)
teacher.eval()
# 加载**学生模型**
student = AutoModelForCausalLM.from_pretrained(
student_model_name,
trust_remote_code=True,
token=hf_token
).to(device)
tokenizer = AutoTokenizer.from_pretrained(
student_model_name, # ✅ **用 Qwen 词表,防止维度不匹配**
trust_remote_code=True,
token=hf_token
)
# **处理 pad_token 问题**
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Load the training data: data.json is expected to be a list of
# {"instruction": ..., "output": ...} records.
with open('data.json', 'r', encoding='utf-8') as f:
    data = json.load(f)
# Basic format check before building the Dataset.
if not isinstance(data, list):
    raise ValueError("data.json 格式错误,需要是一个列表!")
# **格式化数据(加 `chat_template`)**
def format_chat(example):
    """Render one {"instruction", "output"} record in Qwen chat-template form."""
    system_msg = "你是一个粉丝通软件的智能客服助手。"
    segments = [
        f"<|im_start|>system\n{system_msg}\n<|im_end|>",
        f"<|im_start|>user\n{example['instruction']}<|im_end|>",
        f"<|im_start|>assistant\n{example['output']}<|im_end|>",
    ]
    return "\n".join(segments)
def preprocess_data(example):
    """Tokenize one *batch* of records (used with dataset.map(batched=True)).

    Maps the raw "instruction" texts to model inputs and the raw "output"
    texts to labels, both padded/truncated to 128 tokens.  Pad positions in
    the labels are replaced with -100 (CrossEntropyLoss's ignore_index) so
    padding does not contribute to the loss — the original kept the raw pad
    ids, which trained the model to emit pad tokens.

    NOTE(review): this trains instruction -> output directly, without the
    chat template used at inference time; the commented-out variant below
    is the templated alternative.
    """
    inputs = tokenizer(example["instruction"], truncation=True, padding="max_length", max_length=128)
    targets = tokenizer(example["output"], truncation=True, padding="max_length", max_length=128)
    pad_id = tokenizer.pad_token_id
    labels = [
        [tok if tok != pad_id else -100 for tok in seq]
        for seq in targets["input_ids"]
    ]
    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels": labels,
    }
# def preprocess_data(example):
# formatted_text = format_chat(example)
# tokens = tokenizer(formatted_text, truncation=True, padding="max_length", max_length=128)
# return {
# "input_ids": tokens["input_ids"],
# "attention_mask": tokens["attention_mask"],
# "labels": tokens["input_ids"]
# }
# Build the HF Dataset and tokenize it.
dataset = Dataset.from_list(data)
# batched=True: preprocess_data receives lists of instructions/outputs.
dataset = dataset.map(preprocess_data, batched=True)
# ✅ **修正 KL 散度计算**
class DistillationTrainer(Trainer):
    """Trainer mixing hard-label cross-entropy with KL-divergence distillation.

    loss = alpha * CE(student, next-token labels)
         + (1 - alpha) * T^2 * KL(teacher || student)
    """

    def __init__(self, teacher, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keep the frozen teacher on the same device as the student.
        self.teacher = teacher.to(device)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs accepts extras like num_items_in_batch passed by newer
        # transformers releases (forward compatibility; values unused here).
        input_ids = inputs["input_ids"]
        attention_mask = inputs.get("attention_mask")

        # Student forward pass.
        outputs_student = model(**inputs)
        logits_student = outputs_student.logits

        # Teacher forward pass — no gradients needed.
        with torch.no_grad():
            inputs_on_device = {k: v.to(device) for k, v in inputs.items()}
            outputs_teacher = self.teacher(**inputs_on_device)
            logits_teacher = outputs_teacher.logits

        temperature = 2.0
        student_log_probs = torch.nn.functional.log_softmax(logits_student / temperature, dim=-1)
        teacher_probs = torch.nn.functional.softmax(logits_teacher / temperature, dim=-1)

        # Guard against a vocab-size mismatch between teacher and student
        # (should not trigger for this Qwen pair, which share a tokenizer).
        if student_log_probs.shape != teacher_probs.shape:
            min_dim = min(student_log_probs.shape[-1], teacher_probs.shape[-1])
            student_log_probs = student_log_probs[..., :min_dim]
            teacher_probs = teacher_probs[..., :min_dim]

        # KL distillation term, scaled by T^2 (standard soft-target scaling).
        kl_loss = torch.nn.functional.kl_div(
            student_log_probs, teacher_probs, reduction="batchmean"
        ) * (temperature ** 2)

        # Hard-label CE as next-token prediction: the logit at position i
        # predicts token i+1, so shift logits/labels by one.  The original
        # computed CE without this shift, which trains the model to copy
        # its input rather than predict the next token.
        labels = input_ids.clone()
        if attention_mask is not None:
            # Mask padding with -100 (CrossEntropyLoss's ignore_index).
            labels[attention_mask == 0] = -100
        shift_logits = logits_student[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        ce_loss_fct = torch.nn.CrossEntropyLoss()
        ce_loss = ce_loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )

        # Blend the two objectives.
        alpha = 0.5
        loss = alpha * ce_loss + (1 - alpha) * kl_loss
        return (loss, outputs_student) if return_outputs else loss

    def training_step(self, model, inputs, *args, **kwargs):
        """Delegate to the stock Trainer step.

        The stock implementation already moves inputs to the right device
        via _prepare_inputs AND runs the backward pass.  The previous
        override returned the raw loss without ever calling backward, so
        the optimizer stepped on zero gradients and no learning happened.
        """
        return super().training_step(model, inputs, *args, **kwargs)
# Training hyper-parameters.
# NOTE: use_cache is not a TrainingArguments option; it is disabled on the
# model config below instead.
training_args = TrainingArguments(
    output_dir="/tmp/distilled_model",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    learning_rate=2e-5,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    logging_steps=100,
    save_strategy="epoch",
    remove_unused_columns=False,  # keep all dataset columns for compute_loss
    gradient_checkpointing=True,  # trade compute for memory
    fp16=False,  # fp16 caused _scale=None errors; keep disabled
    # bf16 needs hardware support (Ampere or newer); enabling it on any
    # CUDA device (as the original did) crashes on older GPUs, so probe
    # support explicitly.
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported()
)

# gradient_checkpointing is incompatible with the KV cache, so disable the
# cache directly on the model config.
student.config.use_cache = False
# Build the distillation trainer.  Note: eval runs on the training set —
# data.json has no held-out split.
trainer = DistillationTrainer(
    teacher=teacher,
    model=student,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=dataset
)

# Run the distillation.
trainer.train()

# Publish the distilled student and its tokenizer to the Hub.
# `token=` is the current parameter name; `use_auth_token=` is deprecated,
# and the rest of this file already uses `token=`.
student.push_to_hub("Snow2222/fst-nnn", token=hf_token)
tokenizer.push_to_hub("Snow2222/fst-nnn", token=hf_token)
# --- Inference / demo phase ------------------------------------------------
print("🎉 训练完成,启动 Gradio Web 界面...")

# Reload the just-pushed model from the Hub for serving.
model_id = "Snow2222/fst-nnn"

# (Re)select the serving device; note this rebinds `device` as a string.
device = "cuda" if torch.cuda.is_available() else "cpu"

print("🚀 正在加载模型...")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_token,
    trust_remote_code=True
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
print("✅ 模型加载成功!")
# **Gradio 交互函数**
def chat_response(prompt):
    """Generate an assistant reply for `prompt` using the Qwen chat template.

    Returns only the newly generated text, not the echoed prompt.
    """
    chat_input = f"<|im_start|>system\n你是一个粉丝通软件的智能客服助手。\n<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
    inputs = tokenizer(chat_input, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        # max_new_tokens bounds the reply length; the original max_length=100
        # counted the prompt too, so long prompts left no room to generate.
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7
    )
    # Decode only the tokens generated after the prompt — decoding the full
    # sequence (as before) echoed the system/user text back to the user.
    generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True)
# Build the Gradio UI: a single textbox in, the generated reply out.
iface = gr.Interface(
    fn=chat_response,
    inputs=gr.Textbox(lines=2, placeholder="请输入你的问题..."),
    outputs="text",
    title="粉丝通 AI 客服",
    description="基于 Snow2222/fst-nnn 训练的 AI 模型,自动回答你的问题。",
    allow_flagging="never"
)
# Bind to all interfaces on the standard Spaces port (7860).
iface.launch(server_name="0.0.0.0", server_port=7860)