Spaces:
Runtime error
Runtime error
File size: 7,248 Bytes
a8bcc65 0c4575f eff8ebb 0c4575f a8bcc65 e131b6c 51b81ca 65d3709 51b81ca a5f3d74 a8bcc65 0c4575f a8bcc65 e131b6c 93bf8ad bb7b0e9 93bf8ad 0c4575f 93bf8ad 0c4575f 93bf8ad 0c4575f 93bf8ad 0c4575f f5d7ee2 0c4575f 93bf8ad f5d7ee2 93bf8ad f5d7ee2 51b81ca 93bf8ad a5f3d74 0b272a1 93bf8ad a8bcc65 e131b6c 93bf8ad 51b81ca b7e2d99 0c4575f 93bf8ad b7e2d99 9c056a5 515c81d 9c056a5 515c81d 9c056a5 b7e2d99 515c81d 0c4575f b7e2d99 93bf8ad 0c4575f 51b81ca 93bf8ad 51b81ca eb81ebd 0c4575f 51b81ca 93bf8ad 0c4575f e131b6c 93bf8ad 0c4575f 93bf8ad bb7b0e9 0c4575f 93bf8ad 0c4575f 51b81ca 93bf8ad 0c4575f 93bf8ad 0c4575f 93bf8ad 0c4575f e131b6c 93bf8ad eb81ebd 93bf8ad eb81ebd 51b81ca 5f81635 e131b6c aa9f3ba e131b6c 0c4575f e131b6c 0c4575f 93bf8ad 5dc7c29 e131b6c 5f81635 51b81ca 0c4575f 93bf8ad 0c4575f e131b6c a8bcc65 93bf8ad e131b6c a8bcc65 93bf8ad 51b81ca bb7b0e9 eff8ebb b7e2d99 7e74c10 eff8ebb b7e2d99 7e74c10 b7e2d99 7e74c10 eff8ebb b7e2d99 7e74c10 eff8ebb 7e74c10 eff8ebb 7e74c10 eff8ebb 7e74c10 eff8ebb 7e74c10 eff8ebb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 | import os
import json
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from datasets import Dataset
# Keep the Hugging Face cache on the writable /tmp volume (Spaces containers
# have a read-only home directory).
os.environ['HF_HOME'] = '/tmp/huggingface_cache'

# A token is mandatory: gated model downloads and push_to_hub both need it.
hf_token = os.getenv('HF_TOKEN')
if not hf_token:
    raise ValueError("Hugging Face token 未设置")

from huggingface_hub import HfFolder

# Persist the token so downstream hub calls pick it up implicitly.
HfFolder.save_token(hf_token)
# One shared device for every model taking part in distillation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Teacher/student pair: both are Qwen1.5 chat checkpoints, so their
# vocabularies line up for the distillation loss.
teacher_model_name = "Qwen/Qwen1.5-7B-Chat"
student_model_name = "Qwen/Qwen1.5-1.8B-Chat"


def _load_causal_lm(name):
    """Load a causal LM checkpoint onto the shared device."""
    return AutoModelForCausalLM.from_pretrained(
        name,
        trust_remote_code=True,
        token=hf_token
    ).to(device)


teacher = _load_causal_lm(teacher_model_name)
teacher.eval()  # teacher is frozen: inference only

student = _load_causal_lm(student_model_name)

# Tokenizer comes from the student checkpoint so token ids match the
# student's embedding table.
tokenizer = AutoTokenizer.from_pretrained(
    student_model_name,
    trust_remote_code=True,
    token=hf_token
)
# Qwen tokenizers ship without a pad token; reuse EOS for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Load the training records from disk.
with open('data.json', encoding='utf-8') as f:
    data = json.load(f)

# The rest of the pipeline assumes a list of {"instruction", "output"} dicts.
if not isinstance(data, list):
    raise ValueError("data.json 格式错误,需要是一个列表!")
# Render one training record with the Qwen chat template (system/user/assistant turns).
def format_chat(example):
    """Return *example* as a Qwen ``<|im_start|>``/``<|im_end|>`` chat transcript."""
    system_turn = "<|im_start|>system\n你是一个粉丝通软件的智能客服助手。\n<|im_end|>"
    user_turn = f"<|im_start|>user\n{example['instruction']}<|im_end|>"
    assistant_turn = f"<|im_start|>assistant\n{example['output']}<|im_end|>"
    return f"{system_turn}\n{user_turn}\n{assistant_turn}"
# **数据预处理** — tokenize and mask padding in the labels.
def preprocess_data(example):
    """Tokenize a mapped batch (or single example) into model inputs.

    The instruction text becomes the model input; the expected output text
    becomes the labels.  Padding positions in the labels are replaced with
    -100 so that the model's CrossEntropyLoss ignores them instead of
    training the model to emit pad tokens.

    NOTE(review): pad_token is set to EOS at module level, so the genuine
    trailing EOS is masked too — acceptable here, but confirm if EOS
    supervision matters for your use case.
    """
    inputs = tokenizer(example["instruction"], truncation=True, padding="max_length", max_length=128)
    labels = tokenizer(example["output"], truncation=True, padding="max_length", max_length=128)

    pad_id = tokenizer.pad_token_id
    raw_labels = labels["input_ids"]
    if raw_labels and isinstance(raw_labels[0], list):
        # batched=True path: a list of token-id sequences.
        masked_labels = [[-100 if tok == pad_id else tok for tok in seq] for seq in raw_labels]
    else:
        # Single-example path: one flat token-id sequence.
        masked_labels = [-100 if tok == pad_id else tok for tok in raw_labels]

    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels": masked_labels
    }
# def preprocess_data(example):
# formatted_text = format_chat(example)
# tokens = tokenizer(formatted_text, truncation=True, padding="max_length", max_length=128)
# return {
# "input_ids": tokens["input_ids"],
# "attention_mask": tokens["attention_mask"],
# "labels": tokens["input_ids"]
# }
# Materialize the JSON records as a Dataset and tokenize them in batches.
dataset = Dataset.from_list(data).map(preprocess_data, batched=True)
# Custom trainer mixing hard-label cross-entropy with a temperature-scaled
# KL-divergence distillation term against the frozen teacher.
class DistillationTrainer(Trainer):
    def __init__(self, teacher, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Teacher is inference-only; keep it on the shared device.
        self.teacher = teacher.to(device)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Blend causal-LM CE loss with KL distillation.

        ``**kwargs`` absorbs extra arguments (e.g. ``num_items_in_batch``)
        that newer versions of ``transformers`` pass to ``compute_loss``.
        """
        # Hard targets are the input ids themselves (standard causal-LM setup).
        labels = inputs["input_ids"]

        # Student forward pass (gradients flow through this one).
        outputs_student = model(**inputs)
        logits_student = outputs_student.logits

        # Teacher forward pass: no gradients, inputs forced onto its device.
        with torch.no_grad():
            inputs_on_device = {k: v.to(device) for k, v in inputs.items()}
            outputs_teacher = self.teacher(**inputs_on_device)
            logits_teacher = outputs_teacher.logits

        temperature = 2.0
        student_log_probs = torch.nn.functional.log_softmax(logits_student / temperature, dim=-1)
        teacher_probs = torch.nn.functional.softmax(logits_teacher / temperature, dim=-1)

        # Fallback for mismatched vocab sizes: truncate to the common prefix.
        # Both models here are Qwen1.5, so this is normally a no-op.
        if student_log_probs.shape != teacher_probs.shape:
            min_dim = min(student_log_probs.shape[-1], teacher_probs.shape[-1])
            student_log_probs = student_log_probs[..., :min_dim]
            teacher_probs = teacher_probs[..., :min_dim]

        # KL term, scaled by T^2 as in Hinton-style distillation.
        kl_loss = torch.nn.functional.kl_div(
            student_log_probs, teacher_probs, reduction="batchmean"
        ) * (temperature ** 2)

        # Causal-LM CE: predict token t+1 from tokens <= t, so shift logits
        # and labels by one position (the old code compared position t with
        # token t), and ignore padding positions entirely.
        shift_logits = logits_student[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        ce_loss = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=tokenizer.pad_token_id,
        )

        # Equal-weight blend of the two objectives.
        alpha = 0.5
        loss = alpha * ce_loss + (1 - alpha) * kl_loss
        return (loss, outputs_student) if return_outputs else loss

    def training_step(self, model, inputs, *args, **kwargs):
        """Run one optimization step with all inputs on the shared device.

        The previous override returned the loss without ever calling
        backward(), so no gradients were produced and training was a no-op;
        delegate the backward pass to the accelerator like the base
        ``Trainer`` does.
        """
        model.train()
        inputs = {k: v.to(device) for k, v in self._prepare_inputs(inputs).items()}
        loss = self.compute_loss(model, inputs)
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps
        self.accelerator.backward(loss)
        return loss.detach()
# Training configuration.  `use_cache` is not a TrainingArguments field; it
# is disabled on the model config below instead.
training_args = TrainingArguments(
    output_dir="/tmp/distilled_model",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    learning_rate=2e-5,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    logging_steps=100,
    save_strategy="epoch",
    remove_unused_columns=False,  # keep the custom "labels" column for compute_loss
    gradient_checkpointing=True,
    fp16=False,  # fp16 caused `_scale=None` errors; keep disabled
    # bf16 needs Ampere-or-newer hardware, not just any CUDA device; the old
    # `torch.cuda.is_available()`-only check crashed on older GPUs.  The
    # `and` short-circuits so is_bf16_supported() is only queried with CUDA.
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported()
)

# Gradient checkpointing is incompatible with the KV cache; disable it on the
# model config rather than via TrainingArguments.
student.config.use_cache = False
# Build the distillation trainer: the teacher is passed explicitly, the
# student is the trainable model.
trainer = DistillationTrainer(
    teacher=teacher,
    model=student,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=dataset  # NOTE(review): evaluates on the training set — sanity check only
)

# Run training.
trainer.train()

# Publish the distilled student and its tokenizer.  `use_auth_token` is
# deprecated in huggingface_hub; use `token=` like the rest of this file.
student.push_to_hub("Snow2222/fst-nnn", token=hf_token)
tokenizer.push_to_hub("Snow2222/fst-nnn", token=hf_token)
# Serving phase: reload the just-published model for the Gradio UI.
print("🎉 训练完成,启动 Gradio Web 界面...")

model_id = "Snow2222/fst-nnn"

# Re-resolve the device (string form, as .to() accepts both).
device = "cuda" if torch.cuda.is_available() else "cpu"

print("🚀 正在加载模型...")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_token,
    trust_remote_code=True
).to(device)
# trust_remote_code matches every other load in this file; the original
# omitted it here, which can fail for checkpoints with custom tokenizer code.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token, trust_remote_code=True)
print("✅ 模型加载成功!")
# Gradio callback: wrap the user prompt in the Qwen chat template and return
# only the assistant's generated reply.
def chat_response(prompt):
    """Generate a customer-service reply for *prompt*."""
    chat_input = f"<|im_start|>system\n你是一个粉丝通软件的智能客服助手。\n<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
    inputs = tokenizer(chat_input, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        # max_new_tokens bounds only the reply; the old max_length=100
        # counted the prompt too and could leave no room to answer.
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7
    )
    # generate() returns prompt + continuation; strip the prompt tokens so
    # the UI shows just the assistant's answer, not the echoed template.
    reply_ids = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True)
# Build the Gradio UI: a single textbox in, plain text out, flagging disabled.
iface = gr.Interface(
    fn=chat_response,
    inputs=gr.Textbox(lines=2, placeholder="请输入你的问题..."),
    outputs="text",
    title="粉丝通 AI 客服",
    description="基于 Snow2222/fst-nnn 训练的 AI 模型,自动回答你的问题。",
    allow_flagging="never"
)
# Bind to all interfaces on the standard Spaces port so the app is reachable.
iface.launch(server_name="0.0.0.0", server_port=7860)
|