Spaces:
Runtime error
Runtime error
import inspect
import json

import gradio as gr
import torch
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
# Step 1: Load the data.
DATA_FILE = "translation model training data_major_strategy.json"  # training data file name

# Read the JSON data file. Each record is accessed below through the keys
# "source", "translation" and "major_strategy".
with open(DATA_FILE, "r", encoding="utf-8") as f:
    data = json.load(f)

# Preprocessing: join source and translation into one input string with a
# literal "[SEP]" marker between them.
texts = [f"{item['source']} [SEP] {item['translation']}" for item in data]

# Three strategy classes: 创译 = 0, 仿译 = 1, 创仿 = 2.
label_map = {"创译": 0, "仿译": 1, "创仿": 2}


def _encode_label(index, item):
    """Return the class id of one record's ``major_strategy``.

    Raises:
        ValueError: if the label is missing or not a key of ``label_map``.
            The message names the record position and the allowed values, so a
            typo in the data file is easy to locate (a bare KeyError would not).
    """
    strategy = item.get("major_strategy")
    if strategy not in label_map:
        raise ValueError(
            f"Record {index} in {DATA_FILE!r} has unknown major_strategy "
            f"{strategy!r}; expected one of {list(label_map)}"
        )
    return label_map[strategy]


labels = [_encode_label(index, item) for index, item in enumerate(data)]

# Split into training and validation sets: 80/20 with a fixed random_state so
# the split is reproducible across runs.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    texts, labels, test_size=0.2, random_state=42
)
# Step 2: Load the tokenizer for the pretrained checkpoint.
MODEL_NAME = "sentence-transformers/LaBSE"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def tokenize_function(texts):
    """Tokenize ``texts``, truncating and padding every sequence to 128 tokens."""
    encode_options = dict(padding="max_length", truncation=True, max_length=128)
    return tokenizer(texts, **encode_options)


# Encode the training split first, then the validation split.
train_encodings, val_encodings = (
    tokenize_function(split_texts) for split_texts in (train_texts, val_texts)
)
# Convert the encodings and labels into Hugging Face Dataset objects.
def _to_dataset(encodings, split_labels):
    """Build a Dataset with input_ids, attention_mask and labels columns."""
    columns = {
        "input_ids": encodings["input_ids"],
        "attention_mask": encodings["attention_mask"],
        "labels": split_labels,
    }
    return Dataset.from_dict(columns)


train_dataset = _to_dataset(train_encodings, train_labels)
val_dataset = _to_dataset(val_encodings, val_labels)
# Load LaBSE and add a sequence-classification head with one output per strategy.
# Storing id2label/label2id in the model config means the checkpoint saved after
# training carries the strategy names rather than the library's generic label names.
id2label = {label_id: name for name, label_id in label_map.items()}
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(label_map),
    id2label=id2label,
    label2id=label_map,
)

# Step 3: Training configuration.
training_args = TrainingArguments(
    output_dir="./results",          # where checkpoints are written
    eval_strategy="epoch",           # evaluate after every epoch (newer name of evaluation_strategy)
    save_strategy="epoch",           # kept equal to eval_strategy, as load_best_model_at_end requires
    learning_rate=2e-5,              # learning rate
    per_device_train_batch_size=8,   # training batch size per device
    per_device_eval_batch_size=8,    # evaluation batch size per device
    num_train_epochs=3,              # number of training epochs
    weight_decay=0.01,               # weight decay
    save_total_limit=1,              # cap how many checkpoints are kept in output_dir
    load_best_model_at_end=True,     # reload the best checkpoint (by validation metric) when training ends
    logging_dir="./logs",            # log directory
    logging_steps=10,                # log every 10 steps
)
# Custom evaluation metrics.
def compute_metrics(pred):
    """Compute accuracy and weighted precision/recall/F1 for one evaluation pass.

    Args:
        pred: The evaluation-prediction object Trainer passes in; it exposes
            ``predictions`` (per-class model scores) and ``label_ids`` (gold ids).

    Returns:
        Dict with ``accuracy``, ``f1``, ``precision`` and ``recall``; the last
        three are support-weighted averages over the classes.
    """
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support

    labels = pred.label_ids
    scores = pred.predictions
    # predictions may be a tuple when the model returns extra outputs; the
    # classification scores come first.
    if isinstance(scores, tuple):
        scores = scores[0]
    preds = scores.argmax(-1)
    # zero_division=0: a class that is never predicted (common in early epochs)
    # scores 0, exactly as the default does, but without a warning on every eval.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="weighted", zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}
# Define the Trainer.
# transformers 4.46 deprecated Trainer's `tokenizer` argument in favour of
# `processing_class`. Pass the tokenizer under whichever keyword the installed
# Trainer's signature exposes, so the script runs on versions before and after it.
_tokenizer_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    **{_tokenizer_kwarg: tokenizer},
)

# Step 4: Train.
trainer.train()

# Save the fine-tuned model and its tokenizer into the same directory so they
# can be reloaded together with from_pretrained.
model.save_pretrained("./trained_labse_model")
tokenizer.save_pretrained("./trained_labse_model")
# Step 5: Inference service.
def predict_strategy(source, translation):
    """Predict the translation strategy for one source/translation pair.

    Args:
        source: Original text (the Gradio UI asks for the Chinese source).
        translation: Translated text (the Gradio UI asks for the English translation).

    Returns:
        The predicted strategy name, i.e. one of the keys of ``label_map``.
    """
    # Build the input string exactly like the training texts.
    text = f"{source} [SEP] {translation}"
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    # Trainer may leave the model on a GPU after training, while the tokenizer
    # returns CPU tensors; mixing devices makes the forward pass raise.
    inputs = inputs.to(model.device)
    model.eval()  # disable dropout so repeated calls give the same answer
    with torch.no_grad():  # inference only: no autograd graph needed
        logits = model(**inputs).logits
    predicted_class = torch.argmax(logits, dim=1).item()
    # Invert the training label map so ids and names cannot drift apart.
    strategy_map = {label_id: name for name, label_id in label_map.items()}
    return strategy_map[predicted_class]
# Gradio web UI: two text boxes (source, translation) -> predicted strategy text.
_EXAMPLES = [
    ["扛紧制度的笼箍", "Reinforce relevant institutions"],
    ["中国发展的巨轮", "Our country continues to progress steadily"],
    ["发挥巡视利剑作用", "Let discipline inspection cut through corruption like a blade."],
]

interface = gr.Interface(
    fn=predict_strategy,
    inputs=["text"] * 2,
    outputs="text",
    title="Translation Strategy Classifier",
    description="输入中文原文和英文译文,预测翻译策略(创译/仿译/创仿)。",
    examples=_EXAMPLES,
)

# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()