# app.py — Hugging Face Space upload by DATAsoong (commit 4afac1e, verified)
import json
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from datasets import Dataset
import gradio as gr
# Step 1: load the training data
DATA_FILE = "translation model training data_major_strategy.json"  # data file name

# Read the JSON data file
with open(DATA_FILE, "r", encoding="utf-8") as f:
    data = json.load(f)

# Preprocessing: join source and translation into one input string, and map
# each strategy name to an integer class id.
# Three strategies: 创译 (transcreation)=0, 仿译 (imitation)=1, 创仿 (hybrid)=2
label_map = {"创译": 0, "仿译": 1, "创仿": 2}

texts = []
labels = []
for item in data:
    texts.append(f"{item['source']} [SEP] {item['translation']}")
    labels.append(label_map[item['major_strategy']])

# Hold out 20% for validation; fixed seed keeps the split reproducible.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    texts, labels, test_size=0.2, random_state=42
)
# Step 2: load the tokenizer and model
MODEL_NAME = "sentence-transformers/LaBSE"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def tokenize_function(texts, max_length=128):
    """Tokenize text(s) with fixed-length padding and truncation.

    Args:
        texts: a string or list of strings to encode.
        max_length: maximum sequence length; defaults to 128, the value
            used for training, so existing calls are unchanged.

    Returns:
        A BatchEncoding with ``input_ids`` and ``attention_mask`` padded
        to ``max_length``.
    """
    return tokenizer(texts, padding="max_length", truncation=True, max_length=max_length)

train_encodings = tokenize_function(train_texts)
val_encodings = tokenize_function(val_texts)
# Convert the encodings into Hugging Face Dataset objects.
def _as_dataset(encodings, label_list):
    """Build a Dataset from tokenizer output plus integer labels."""
    return Dataset.from_dict({
        "input_ids": encodings["input_ids"],
        "attention_mask": encodings["attention_mask"],
        "labels": label_list,
    })

train_dataset = _as_dataset(train_encodings, train_labels)
val_dataset = _as_dataset(val_encodings, val_labels)

# Load LaBSE with a fresh classification head (num_labels=3 for the 3-way task)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
# Step 3: training configuration.
# Evaluation and checkpointing both run once per epoch so that
# load_best_model_at_end can restore the best epoch checkpoint.
training_args = TrainingArguments(
    output_dir="./results",            # where checkpoints are written
    logging_dir="./logs",              # tensorboard/log output
    logging_steps=10,                  # log every 10 optimizer steps
    eval_strategy="epoch",             # newer name for evaluation_strategy
    save_strategy="epoch",             # must match the eval strategy
    load_best_model_at_end=True,       # restore best validation checkpoint
    save_total_limit=1,                # keep only one checkpoint on disk
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
)
# Custom evaluation metrics for the Trainer
def compute_metrics(pred):
    """Return accuracy plus weighted precision/recall/F1 for a prediction batch."""
    # Imported lazily, matching the original's local-import style.
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1": f1,
        "precision": prec,
        "recall": rec,
    }
# Wire everything into a Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# Step 4: fine-tune the model
trainer.train()

# Persist the fine-tuned model together with its tokenizer
SAVE_DIR = "./trained_labse_model"
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
# Step 5: inference service
# Inverse of label_map: class index -> strategy name
STRATEGY_NAMES = {0: "创译", 1: "仿译", 2: "创仿"}

def predict_strategy(source, translation):
    """Predict the translation strategy for a source/translation pair.

    Args:
        source: the source-language text.
        translation: the target-language translation.

    Returns:
        The predicted strategy label: "创译", "仿译", or "创仿".
    """
    text = f"{source} [SEP] {translation}"
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    # Fixes over the original: eval() disables dropout so predictions are
    # deterministic, and no_grad() skips autograd bookkeeping during inference.
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    return STRATEGY_NAMES[predicted_class]
# Build the Gradio web UI for interactive prediction
demo_examples = [
    ["扛紧制度的笼箍", "Reinforce relevant institutions"],
    ["中国发展的巨轮", "Our country continues to progress steadily"],
    ["发挥巡视利剑作用", "Let discipline inspection cut through corruption like a blade."],
]

interface = gr.Interface(
    fn=predict_strategy,
    inputs=["text", "text"],
    outputs="text",
    title="Translation Strategy Classifier",
    description="输入中文原文和英文译文,预测翻译策略(创译/仿译/创仿)。",
    examples=demo_examples,
)

# Launch the app only when run as a script (not when imported)
if __name__ == "__main__":
    interface.launch()