File size: 7,248 Bytes
a8bcc65
 
0c4575f
eff8ebb
0c4575f
a8bcc65
e131b6c
51b81ca
65d3709
 
51b81ca
a5f3d74
a8bcc65
0c4575f
a8bcc65
 
 
e131b6c
93bf8ad
bb7b0e9
 
93bf8ad
 
 
0c4575f
93bf8ad
0c4575f
 
 
 
93bf8ad
 
0c4575f
93bf8ad
0c4575f
 
f5d7ee2
0c4575f
93bf8ad
f5d7ee2
93bf8ad
f5d7ee2
 
 
51b81ca
93bf8ad
a5f3d74
 
0b272a1
93bf8ad
a8bcc65
 
e131b6c
93bf8ad
51b81ca
 
 
b7e2d99
 
 
 
 
 
 
 
0c4575f
93bf8ad
b7e2d99
9c056a5
 
 
515c81d
9c056a5
 
 
515c81d
9c056a5
 
 
 
 
 
 
 
 
b7e2d99
515c81d
0c4575f
b7e2d99
93bf8ad
0c4575f
51b81ca
 
93bf8ad
51b81ca
eb81ebd
0c4575f
51b81ca
93bf8ad
0c4575f
 
e131b6c
93bf8ad
0c4575f
93bf8ad
bb7b0e9
0c4575f
 
 
93bf8ad
0c4575f
 
51b81ca
93bf8ad
 
 
 
 
 
 
0c4575f
 
93bf8ad
0c4575f
 
 
93bf8ad
0c4575f
 
 
 
e131b6c
93bf8ad
 
eb81ebd
93bf8ad
 
eb81ebd
 
51b81ca
5f81635
e131b6c
aa9f3ba
e131b6c
0c4575f
 
e131b6c
0c4575f
 
 
93bf8ad
 
5dc7c29
 
 
e131b6c
 
5f81635
 
 
51b81ca
0c4575f
93bf8ad
0c4575f
e131b6c
a8bcc65
93bf8ad
e131b6c
 
 
 
a8bcc65
93bf8ad
51b81ca
bb7b0e9
eff8ebb
 
 
 
b7e2d99
7e74c10
eff8ebb
b7e2d99
7e74c10
 
 
b7e2d99
 
 
 
 
 
 
 
 
 
7e74c10
 
eff8ebb
b7e2d99
 
 
 
 
 
 
 
7e74c10
eff8ebb
 
7e74c10
eff8ebb
 
7e74c10
eff8ebb
7e74c10
 
 
eff8ebb
 
7e74c10
eff8ebb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import os

# Point the Hugging Face cache at /tmp. This must run BEFORE transformers / datasets /
# gradio are imported: they (via huggingface_hub) resolve their cache directories from
# HF_HOME once, at import time, so assigning it after those imports has no effect.
os.environ['HF_HOME'] = '/tmp/huggingface_cache'

import json
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from datasets import Dataset

# Read the Hugging Face access token from the HF_TOKEN environment variable and store it
# with HfFolder.save_token; abort immediately if it is not set.
hf_token = os.getenv('HF_TOKEN')
if hf_token:
    from huggingface_hub import HfFolder
    HfFolder.save_token(hf_token)
else:
    raise ValueError("Hugging Face token 未设置")

# Single device for everything: the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Large teacher and small student checkpoints, both from the Qwen1.5 chat family.
teacher_model_name = "Qwen/Qwen1.5-7B-Chat"
student_model_name = "Qwen/Qwen1.5-1.8B-Chat"

# Every Hub download below uses the same access options.
hub_kwargs = {"trust_remote_code": True, "token": hf_token}

# The teacher is only used for inference, so it is switched to eval mode right away.
teacher = AutoModelForCausalLM.from_pretrained(teacher_model_name, **hub_kwargs).to(device)
teacher.eval()

# The student is the model being trained; the tokenizer is taken from the student
# checkpoint (per the original note, to keep the vocabulary dimensions matched).
student = AutoModelForCausalLM.from_pretrained(student_model_name, **hub_kwargs).to(device)
tokenizer = AutoTokenizer.from_pretrained(student_model_name, **hub_kwargs)

# Padding to a fixed length needs a pad token; fall back to EOS when none is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the training examples. Expected shape: a non-empty JSON list of objects with
# string "instruction" and "output" fields (the keys the preprocessing reads).
with open('data.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# Validate up front so malformed data fails here with a clear message instead of
# deep inside Dataset.map / tokenization.
if not isinstance(data, list):
    raise ValueError("data.json 格式错误,需要是一个列表!")
if not data:
    raise ValueError("data.json 为空,没有可用于训练的数据!")
for index, item in enumerate(data):
    if not (
        isinstance(item, dict)
        and isinstance(item.get("instruction"), str)
        and isinstance(item.get("output"), str)
    ):
        raise ValueError(
            f"data.json 第 {index} 条数据格式错误,需要包含字符串字段 'instruction' 和 'output'!"
        )

# **格式化数据(加 `chat_template`)**
def format_chat(example):
    """Render one training pair as a ChatML conversation string.

    The text consists of a fixed system turn, a user turn holding
    ``example["instruction"]`` and an assistant turn holding ``example["output"]``,
    each closed by ``<|im_end|>``.

    Args:
        example: Mapping with ``"instruction"`` and ``"output"`` entries (looked up
            in that order; a missing key raises ``KeyError``).

    Returns:
        The formatted conversation text.
    """
    instruction, output = example["instruction"], example["output"]
    return (
        "<|im_start|>system\n你是一个粉丝通软件的智能客服助手。\n<|im_end|>\n"
        f"<|im_start|>user\n{instruction}<|im_end|>\n"
        f"<|im_start|>assistant\n{output}<|im_end|>"
    )




# **数据预处理**
def preprocess_data(example, max_length=128):
    """Tokenize instruction/output pairs for causal-LM (distillation) training.

    Each pair is rendered with ``format_chat`` (the same ChatML layout used at
    inference time) so prompt and answer live in ONE sequence. Labels follow the
    Hugging Face causal-LM convention: they are aligned with ``input_ids`` (the
    loss shifts them), and every prompt token and padding position is set to
    -100 so only the assistant's answer is learned.

    Works with ``Dataset.map(..., batched=True)`` (values are lists) as well as
    with a single example (values are scalars).

    Args:
        example: Mapping with ``"instruction"`` and ``"output"`` entries.
        max_length: Sequences are truncated and right-padded to this many tokens.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels``: lists of rows
        for a batch, a single row for a single example.
    """
    instructions = example["instruction"]
    outputs = example["output"]
    is_batch = isinstance(instructions, list)
    if not is_batch:
        instructions, outputs = [instructions], [outputs]

    end_tag = "<|im_end|>"
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for instruction, output in zip(instructions, outputs):
        full_text = format_chat({"instruction": instruction, "output": output})
        # The prompt is everything up to the open assistant turn: render with an
        # empty answer and drop the trailing end tag.
        prompt_text = format_chat({"instruction": instruction, "output": ""})[: -len(end_tag)]
        prompt_len = len(tokenizer(prompt_text)["input_ids"])

        ids = tokenizer(full_text, truncation=True, max_length=max_length)["input_ids"]
        pad_len = max_length - len(ids)
        # Pad manually on the right so masking the prompt by position stays valid
        # regardless of tokenizer.padding_side.
        batch["input_ids"].append(ids + [tokenizer.pad_token_id] * pad_len)
        batch["attention_mask"].append([1] * len(ids) + [0] * pad_len)
        batch["labels"].append(
            [-100] * min(prompt_len, len(ids)) + ids[prompt_len:] + [-100] * pad_len
        )

    if is_batch:
        return batch
    return {key: rows[0] for key, rows in batch.items()}

# Build the training dataset and tokenize it. The raw columns (instruction/output and
# anything else present in data.json) are dropped so batches contain only the
# tokenized fields; with remove_unused_columns=False they would otherwise be passed
# through to the collator and the model.
dataset = Dataset.from_list(data)
dataset = dataset.map(preprocess_data, batched=True, remove_columns=dataset.column_names)


# ✅ **修正 KL 散度计算**
class DistillationTrainer(Trainer):
    """``transformers.Trainer`` that distils a frozen teacher model into the student.

    Loss per batch: ``alpha * CE + (1 - alpha) * KL``.

    * CE: next-token cross-entropy of the student. Targets come from
      ``inputs["labels"]`` (Hugging Face convention: aligned with ``input_ids``,
      -100 means "ignore"); if a batch has no labels, ``input_ids`` are used with
      padding positions (``attention_mask == 0``) ignored.
    * KL: KL(teacher || student) between temperature-softened next-token
      distributions, averaged over non-padding positions and multiplied by
      ``temperature ** 2``.

    ``training_step`` is deliberately NOT overridden: the base implementation moves
    inputs to the device, runs the backward pass and returns a detached loss. A
    replacement that only computes the loss leaves the student without gradients.

    Args:
        teacher: Causal LM supplying soft targets; moved to ``device``, put in eval
            mode and only run under ``torch.no_grad()``.
        *args, **kwargs: Forwarded unchanged to ``transformers.Trainer``.
        temperature: Softmax temperature for the KL term (default 2.0).
        alpha: Weight of the CE term; the KL term gets ``1 - alpha`` (default 0.5).
    """

    def __init__(self, teacher, *args, temperature=2.0, alpha=0.5, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher.to(device)
        self.teacher.eval()  # inference only: no dropout
        self.temperature = temperature
        self.alpha = alpha

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the mixed CE + KL distillation loss (plus student outputs if requested).

        ``num_items_in_batch`` is accepted for compatibility with newer
        ``Trainer.training_step`` signatures; it is unused because both loss terms
        here are already per-token means.
        """
        F = torch.nn.functional

        input_ids = inputs["input_ids"]
        attention_mask = inputs.get("attention_mask")
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        labels = inputs.get("labels")
        if labels is None:
            labels = input_ids.masked_fill(attention_mask == 0, -100)

        # Forward only the model inputs: passing `labels` would make both models
        # compute an internal loss that is thrown away, and any extra batch keys
        # would be rejected by forward().
        model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}

        outputs_student = model(**model_inputs)
        logits_student = outputs_student.logits

        with torch.no_grad():
            teacher_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            # The teacher never generates here, so skip building its KV cache.
            logits_teacher = self.teacher(**teacher_inputs, use_cache=False).logits.to(
                device=logits_student.device, dtype=logits_student.dtype
            )

        # --- Cross-entropy: logits at position t predict token t + 1, so shift. ---
        shift_logits = logits_student[:, :-1, :]
        shift_labels = labels[:, 1:].to(shift_logits.device)
        ce_sum = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            ignore_index=-100,
            reduction="sum",
        )
        # Clamp so a batch whose labels are all masked gives 0 instead of NaN.
        ce_loss = ce_sum / (shift_labels != -100).sum().clamp(min=1)

        # --- KL divergence on real (non-padding) tokens only. ---
        # If vocab sizes ever differ, truncate the *logits* (not probabilities) so
        # both softmaxes remain properly normalised.
        vocab = min(logits_student.size(-1), logits_teacher.size(-1))
        valid = attention_mask.to(logits_student.device).bool()
        temperature = self.temperature
        student_log_probs = F.log_softmax(logits_student[..., :vocab][valid] / temperature, dim=-1)
        teacher_log_probs = F.log_softmax(logits_teacher[..., :vocab][valid] / temperature, dim=-1)
        # Rows are individual tokens here, so "batchmean" is a per-token mean.
        kl_loss = F.kl_div(
            student_log_probs, teacher_log_probs, reduction="batchmean", log_target=True
        ) * (temperature ** 2)

        loss = self.alpha * ce_loss + (1 - self.alpha) * kl_loss
        return (loss, outputs_student) if return_outputs else loss

# Training hyper-parameters. `use_cache` is intentionally not a TrainingArguments
# option; it is switched off on the student's config below instead.
training_args = TrainingArguments(
    output_dir="/tmp/distilled_model",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    learning_rate=2e-5,
    weight_decay=0.01,
    eval_strategy="epoch",  # `evaluation_strategy` is the deprecated spelling (removed in newer releases)
    logging_steps=100,
    save_strategy="epoch",
    remove_unused_columns=False,
    gradient_checkpointing=True,
    fp16=False,  # fp16 stays off to avoid the GradScaler `_scale=None` error seen earlier
    # Only enable bf16 when the GPU actually supports it; a CUDA device alone is not enough.
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
)

# The KV cache is useless while training and incompatible with gradient checkpointing.
student.config.use_cache = False

trainer = DistillationTrainer(
    teacher=teacher,
    model=student,
    args=training_args,
    train_dataset=dataset,
    # NOTE(review): evaluation reuses the training set, so eval loss measures fit, not generalisation.
    eval_dataset=dataset,
)

trainer.train()

# Re-enable the KV cache before publishing; otherwise the saved config makes
# generate() run without a cache after the model is reloaded from the Hub.
student.config.use_cache = True
student.push_to_hub("Snow2222/fst-nnn", token=hf_token)
tokenizer.push_to_hub("Snow2222/fst-nnn", token=hf_token)

# Deploy the Gradio web interface.
print("🎉 训练完成,启动 Gradio Web 界面...")

# Drop the training-time objects (trainer with its optimizer state, teacher, student)
# before loading the published model, so their GPU memory can be reclaimed instead of
# staying resident next to the freshly loaded copy.
import gc

del trainer, teacher, student
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Hub repository the trained student was pushed to.
model_id = "Snow2222/fst-nnn"

# Inference device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Reload the fine-tuned model and its tokenizer from the Hub.
print("🚀 正在加载模型...")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_token,
    trust_remote_code=True
).to(device)

tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token, trust_remote_code=True)

print("✅ 模型加载成功!")

# **Gradio 交互函数**
def chat_response(prompt, max_new_tokens=100, temperature=0.7):
    """Generate a customer-service reply to one user question.

    The question is wrapped in the same ChatML layout used for training (system
    turn, user turn, open assistant turn) and a continuation is sampled from the
    reloaded ``model``.

    Args:
        prompt: The user's question.
        max_new_tokens: Maximum number of generated tokens, excluding the prompt.
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        Only the newly generated assistant text, with special tokens removed.
    """
    chat_input = f"<|im_start|>system\n你是一个粉丝通软件的智能客服助手。\n<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
    inputs = tokenizer(chat_input, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[-1]
    outputs = model.generate(
        **inputs,
        # max_new_tokens (unlike max_length) does not count the template and question.
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the continuation so the
    # system prompt and the user's question are not echoed back.
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

# Gradio front-end: a two-line question box in, plain text out, flagging disabled.
question_box = gr.Textbox(lines=2, placeholder="请输入你的问题...")
iface = gr.Interface(
    fn=chat_response,
    inputs=question_box,
    outputs="text",
    title="粉丝通 AI 客服",
    description="基于 Snow2222/fst-nnn 训练的 AI 模型,自动回答你的问题。",
    allow_flagging="never",
)

# Serve on all network interfaces, port 7860.
iface.launch(server_name="0.0.0.0", server_port=7860)