File size: 5,247 Bytes
af9853e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
import pandas as pd
from datasets import load_dataset, Dataset, concatenate_datasets, load_from_disk
from .config import Config

class DataProcessor:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def load_clap_data(self):
        """
        加载 clapAI/MultiLingualSentiment 数据集的中文部分
        """
        print("Loading clapAI/MultiLingualSentiment (zh)...")
        try:
            # 假设数据集结构支持 language='zh' 筛选,或者我们加载后筛选
            # 注意:实际使用时可能需要根据具体 Hugging Face dataset 的 config name 调整
            ds = load_dataset("clapAI/MultiLingualSentiment", "zh", split="train", trust_remote_code=True)
        except Exception:
            # Fallback if specific config not found, load all and filter (demo logic)
            print("Warning: Could not load 'zh' specific config, attempting to load generic...")
            ds = load_dataset("clapAI/MultiLingualSentiment", split="train", trust_remote_code=True)
            ds = ds.filter(lambda x: x['language'] == 'zh')
        
        # 映射标签 (假设原标签格式需要调整,这里做通用处理)
        # 假设原数据集 label已经是 0,1,2 或者需要 map
        # 这里为了演示,我们假设它已经是标准格式,或者我们需要查看数据结构
        # 为保证稳健性,我们在 map_function 中处理
        return ds

    def load_medical_data(self):
        """
        加载 OpenModels/Chinese-Herbal-Medicine-Sentiment 垂直领域数据
        """
        print("Loading OpenModels/Chinese-Herbal-Medicine-Sentiment...")
        ds = load_dataset("OpenModels/Chinese-Herbal-Medicine-Sentiment", split="train", trust_remote_code=True)
        return ds

    def clean_data(self, examples):
        """
        数据清洗逻辑
        """
        text = examples['text']
        
        # 1. 剔除“默认好评”噪音
        if "此用户未填写评价内容" in text:
            return False
            
        # 简单长度过滤,太短的可能无意义
        if len(text.strip()) < 2:
            return False
            
        return True

    def unify_labels(self, example):
        """
        统一标签为: 0 (Negative), 1 (Neutral), 2 (Positive)
        """
        label = example['label']
        
        # 根据数据集实际情况调整映射逻辑
        # 这里假设传入的数据集 label 可能是 string 或 int
        # 这是一个示例映射,实际运行时需根据 print(ds.features) 确认
        if isinstance(label, str):
            label = label.lower()
            if label in ['negative', 'pos', '0']: # 示例
                return {'labels': 0}
            elif label in ['neutral', 'neu', '1']:
                return {'labels': 1}
            elif label in ['positive', 'neg', '2']:
                return {'labels': 2}
        
        # 如果已经是 int,确保在 0-2 之间
        return {'labels': int(label)}

    def tokenize_function(self, examples):
        return self.tokenizer(
            examples['text'],
            padding="max_length",
            truncation=True,
            max_length=Config.MAX_LENGTH
        )

    def get_processed_dataset(self, cache_dir=None, num_proc=1):
        # 默认使用 Config.DATA_DIR 作为缓存目录
        if cache_dir is None:
            cache_dir = Config.DATA_DIR

        # 0. 尝试从本地加载已处理的数据
        processed_path = os.path.join(cache_dir, "processed_dataset")
        if os.path.exists(processed_path):
            print(f"Loading processed dataset from {processed_path}...")
            return load_from_disk(processed_path)

        # 1. 加载数据
        ds_clap = self.load_clap_data()
        ds_med = self.load_medical_data()
        
        # 2. 统一列名 (确保都有 'text' 和 'label')
        # OpenModels keys: ['username', 'user_id', 'review_text', 'review_time', 'rating', 'product_id', 'sentiment_label', 'source_file']
        if 'review_text' in ds_med.column_names:
            ds_med = ds_med.rename_column('review_text', 'text')
        if 'sentiment_label' in ds_med.column_names:
            ds_med = ds_med.rename_column('sentiment_label', 'label')
        
        # 3. 数据清洗
        print("Cleaning datasets...")
        ds_med = ds_med.filter(self.clean_data)
        ds_clap = ds_clap.filter(self.clean_data)
        
        # 4. 合并
        # 确保 features 一致
        common_cols = ['text', 'label']
        ds_clap = ds_clap.select_columns(common_cols)
        ds_med = ds_med.select_columns(common_cols)
        
        combined_ds = concatenate_datasets([ds_clap, ds_med])
        
        # 5.标签处理 & Tokenization
        # transform label -> labels
        combined_ds = combined_ds.map(self.unify_labels, remove_columns=['label'])
        
        # tokenize and remove text
        tokenized_ds = combined_ds.map(
            self.tokenize_function, 
            batched=True, 
            remove_columns=['text']
        )
        
        # 划分训练集和验证集
        split_ds = tokenized_ds.train_test_split(test_size=0.1)
        
        return split_ds