File size: 3,870 Bytes
1d70196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import xml.etree.ElementTree as ET
import pandas as pd
import os
import argparse
import requests

# Aspect categories that parse_semeval_xml keeps by default; rows whose
# category is not listed here are dropped.
ASPECT_CATEGORIES = [
    'food',
    'service',
    'ambiance',
    'price',
    'anecdotes/miscellaneous',
]

def download_file(url: str, dest_path: str, timeout: float = 60.0) -> None:
    """Download ``url`` and save the response body to ``dest_path``.

    The body is streamed in chunks to a temporary ``<dest_path>.part`` file
    and only renamed into place once the whole transfer succeeded. An
    interrupted download therefore never leaves a truncated file at
    ``dest_path``, which matters because callers skip downloading when that
    path already exists.

    Args:
        url: HTTP(S) URL to fetch.
        dest_path: Where to write the downloaded bytes.
        timeout: Seconds to wait on the connection or between received bytes.
            ``requests`` applies no timeout by default, so without one a
            stalled server would hang the script indefinitely.

    Raises:
        requests.exceptions.RequestException: On connection errors, timeouts,
            or a non-2xx HTTP status.
        OSError: If the destination cannot be written.
    """
    print(f"Downloading from {url}...")
    tmp_path = dest_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=64 * 1024):
                    f.write(chunk)
        # Atomic on the same filesystem: dest_path is either absent or complete.
        os.replace(tmp_path, dest_path)
    except BaseException:
        # Clean up the partial file (even on Ctrl-C) so a retry starts fresh.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f"Saved to {dest_path}")

def parse_semeval_xml(xml_path: str, categories=None) -> pd.DataFrame:
    """Extract (sentence, aspect category, polarity) rows from a SemEval ABSA XML file.

    Every ``<sentence>`` element below the root that has non-blank ``<text>``
    is inspected. Each of its ``<aspectCategories>/<aspectCategory>`` children
    becomes one row when both conditions hold:

    * its ``category`` attribute is in the allowed set;
    * its ``polarity`` is one of positive/negative/neutral/conflict.

    A sentence with several aspect categories therefore yields several rows
    that share the same ``text``.

    Args:
        xml_path: Path to the XML file.
        categories: Iterable of aspect categories to keep. ``None`` (the
            default) uses the module-level ``ASPECT_CATEGORIES``.

    Returns:
        DataFrame with columns ``text``, ``aspect`` and ``sentiment``. The
        columns are present even when no rows match, so callers can rely on
        the schema.

    Raises:
        xml.etree.ElementTree.ParseError: If the file is not well-formed XML.
        OSError: If the file cannot be read.
    """
    print(f"Parsing: {xml_path}")
    allowed_categories = set(ASPECT_CATEGORIES if categories is None else categories)
    valid_polarities = {'positive', 'negative', 'neutral', 'conflict'}

    root = ET.parse(xml_path).getroot()
    rows = []
    for sentence in root.findall('.//sentence'):
        text_node = sentence.find('text')
        text = (text_node.text or '').strip() if text_node is not None else ''
        if not text:
            # Missing, empty, or whitespace-only text carries no training signal.
            continue

        aspects_node = sentence.find('aspectCategories')
        if aspects_node is None:
            continue

        for aspect_cat in aspects_node.findall('aspectCategory'):
            category = aspect_cat.get('category')
            polarity = aspect_cat.get('polarity')
            if category in allowed_categories and polarity in valid_polarities:
                rows.append({'text': text, 'aspect': category, 'sentiment': polarity})

    # Explicit columns keep the schema intact when `rows` is empty.
    df = pd.DataFrame(rows, columns=['text', 'aspect', 'sentiment'])
    print(f"Extracted {len(df)} aspect-sentiment pairs from {df['text'].nunique()} unique sentences.")
    return df

def _resolve_train_xml(raw_dir: str, url: str) -> str:
    """Return the path of the training XML, downloading it into ``raw_dir`` if absent.

    A copy at ``dataset/Restaurants_Train_v2.xml`` (relative to the current
    working directory) takes precedence over anything in ``raw_dir``.

    Raises:
        requests.exceptions.RequestException: If a download is needed and fails.
    """
    local_copy = 'dataset/Restaurants_Train_v2.xml'
    if os.path.exists(local_copy):
        print("Found local dataset directory!")
        return local_copy
    train_xml = os.path.join(raw_dir, 'Restaurants_Train_v2.xml')
    if not os.path.exists(train_xml):
        download_file(url, train_xml)
    return train_xml

def _split_by_sentence(df: pd.DataFrame, test_size: float, seed: int):
    """Split ``df`` into ``(train_df, test_df)`` keeping each sentence on one side.

    parse_semeval_xml emits one row per aspect category, so the same sentence
    text appears in several rows. A plain row-level split would put that text
    in both files and leak test data into training. ``test_size`` is therefore
    the fraction of *unique sentences* held out. Each side is shuffled with
    ``seed`` so the output is reproducible.
    """
    from sklearn.model_selection import train_test_split

    sentences = df['text'].drop_duplicates()
    _, test_sentences = train_test_split(sentences, test_size=test_size, random_state=seed)
    in_test = df['text'].isin(test_sentences)
    train_df = df[~in_test].sample(frac=1.0, random_state=seed)
    test_df = df[in_test].sample(frac=1.0, random_state=seed)
    return train_df, test_df

def main():
    """CLI entry point: fetch the SemEval-2014 restaurants XML and write train/test CSVs.

    Exits with a non-zero status (via ``SystemExit``) in two cases: the
    dataset cannot be downloaded, or it contains no usable rows.
    """
    parser = argparse.ArgumentParser(description="Convert SemEval XML to CSV for RoBERTa fine-tuning.")
    parser.add_argument('--out_dir', type=str, default='data/processed', help='Directory to save CSVs')
    parser.add_argument('--raw_dir', type=str, default='data/raw', help='Directory to save raw XMLs')
    parser.add_argument('--test_size', type=float, default=0.2,
                        help='Fraction of unique sentences held out for test.csv')
    parser.add_argument('--seed', type=int, default=42, help='Random seed for the train/test split')

    args = parser.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)
    os.makedirs(args.raw_dir, exist_ok=True)

    # A reliable mirror for the ABSA dataset
    train_url = "https://s3.amazonaws.com/fast-ai-nlp/semeval2014_task4/Restaurants_Train_v2.xml"

    # Only the training file is used; the test split is carved out of it
    # (per the original author, the test phase B file is missing 'polarity').
    try:
        train_xml = _resolve_train_xml(args.raw_dir, train_url)
    except requests.exceptions.RequestException as e:
        print(f"Failed to download dataset: {e}")
        raise SystemExit(
            f"Please manually place 'Restaurants_Train_v2.xml' in the {args.raw_dir} folder."
        )

    full_df = parse_semeval_xml(train_xml)
    if full_df.empty:
        raise SystemExit(f"No usable aspect-sentiment pairs found in {train_xml}.")

    train_df, test_df = _split_by_sentence(full_df, args.test_size, args.seed)

    out_path_train = os.path.join(args.out_dir, 'train.csv')
    train_df.to_csv(out_path_train, index=False)
    print(f"Saved training data ({len(train_df)} rows) to: {out_path_train}")

    out_path_test = os.path.join(args.out_dir, 'test.csv')
    test_df.to_csv(out_path_test, index=False)
    print(f"Saved testing data ({len(test_df)} rows) to: {out_path_test}")

    print("\nData processing complete! You can now upload train.csv and test.csv to Colab.")

if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()