File size: 5,903 Bytes
69a2c97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import requests
import pandas as pd
import time
import os

# --- File locations and augmentation service endpoint ---
INPUT_CSV = "dataset_test.csv"
OUTPUT_CSV = "augmented_dataset2.csv"
API_URL = "http://127.0.0.1:8000"

# --- Hyperparameters (can be modified for subsequent tests) ---
# EDA_P is sent to the /augment endpoint as "eda_p";
# SBERT_THRESH is sent to the /filter endpoint as "threshold".
EDA_P = 0.15
SBERT_THRESH = 0.70

# Complete list of configurations as (method, pivot_language) pairs.
# The pivot language is only meaningful for back-translation ("BT");
# "EDA" (lexical perturbations) and "LLM" (generative paraphrasing) use the
# placeholder "none".
METHODS_TO_TEST = [
    ("EDA", "none"),
    # Back-translation via English, German and Czech, in that order.
    *[("BT", pivot_lang) for pivot_lang in ("en", "de", "cs")],
    ("LLM", "none"),
]

# Upper bound (in seconds) for a single API round-trip. Without a timeout a
# stalled server would block the whole run indefinitely.
_REQUEST_TIMEOUT_SECONDS = 120


def _is_rate_limited(status_code, body):
    """Return True when a rejected /augment response looks like throttling.

    Checks the HTTP status itself as well as the response body, because the
    original heuristic looked for "429", "502" or "rate limit" in the body text.
    """
    return (
        status_code in (429, 502)
        or "429" in body
        or "502" in body
        or "rate limit" in body.lower()
    )


def _generate_sample(row_id, original_text, label, method, pivot):
    """Augment one text with one method and pass it through the S-BERT filter.

    Args:
        row_id: Value of the source row's ``id`` column; used to build the
            synthetic sample's id.
        original_text: The text to augment.
        label: Label copied unchanged onto the synthetic sample.
        method: Augmentation method name sent to the API ("EDA", "BT", "LLM").
        pivot: Pivot language sent to the API as ``pivot_lang``.

    Returns:
        A dict describing the accepted synthetic sample, or ``None`` when the
        API rejected the request, returned no usable text, the filter rejected
        the sample, or a request/parsing error occurred.
    """
    method_name = f"{method}_{pivot}" if method == "BT" else method

    try:
        # 1. GENERATION PHASE
        payload = {
            "text": original_text,
            "method": method,
            "pivot_lang": pivot,
            "eda_p": EDA_P,
        }
        aug_response = requests.post(
            f"{API_URL}/augment", json=payload, timeout=_REQUEST_TIMEOUT_SECONDS
        )

        # Check if the server responded with 200 OK status
        if aug_response.status_code != 200:
            print(f"  API Rejected ({aug_response.status_code}) for {method_name}")

            # Defensive mechanism against API rate limits (e.g., Groq Cloud)
            if _is_rate_limited(aug_response.status_code, aug_response.text):
                print("  API Rate limit reached. Sleeping for 60 seconds...")
                time.sleep(60)
            return None

        augmented_text = aug_response.json().get("augmented")
        if not augmented_text or augmented_text in ["BŁĄD", "ERROR"]:
            return None

        # 2. S-BERT SEMANTIC FILTRATION PHASE
        filter_payload = {
            "original": original_text,
            "augmented": augmented_text,
            "threshold": SBERT_THRESH,
        }
        filt_response = requests.post(
            f"{API_URL}/filter", json=filter_payload, timeout=_REQUEST_TIMEOUT_SECONDS
        )
        if filt_response.status_code != 200:
            # An error body carries no similarity verdict; do not parse it as one.
            print(f"  Filter Rejected ({filt_response.status_code}) for {method_name}")
            return None

        filt_res = filt_response.json()
        # A missing or null similarity is treated as 0 so formatting cannot fail.
        sim_score = float(filt_res.get("similarity") or 0)
        passed = filt_res.get("passed", False)

    except requests.exceptions.RequestException as e:
        print(f"  Network connection error for {method_name}: {e}")
        return None
    except ValueError as e:
        # Non-JSON body (older requests versions raise a plain ValueError from
        # .json()) or a non-numeric similarity value.
        print(f"  Malformed API response for {method_name}: {e}")
        return None

    # 3. ACCEPT / REJECT DECISION
    if not passed:
        print(f"  {method_name.ljust(6)} REJECTED (Sim: {sim_score:.3f})")
        return None

    print(f"  {method_name.ljust(6)} ACCEPTED (Sim: {sim_score:.3f})")
    return {
        "id": f"{row_id}_aug_{method_name}",
        "text": augmented_text,
        "label": label,
        "is_synthetic": True,
        "source_method": method_name,
        "similarity_score": sim_score,
    }


def _persist_augmented_pool(baseline_df, augmented_pool):
    """Write the accepted synthetic samples to OUTPUT_CSV.

    If OUTPUT_CSV already exists the new samples are appended to it and
    duplicates by ``id`` are dropped (keeping the newest). Otherwise the file
    is created from ``baseline_df`` (marked as original data) plus the new
    samples. Failures are reported on stdout instead of being raised.
    """
    if not augmented_pool:
        print("\nNo new accepted samples to serialize.")
        return

    print("\nWriting generated samples to disk...")
    new_data_df = pd.DataFrame(augmented_pool)

    try:
        if os.path.exists(OUTPUT_CSV):
            existing_df = pd.read_csv(OUTPUT_CSV)
            final_dataset = pd.concat([existing_df, new_data_df], ignore_index=True)
            # Remove any accidental duplicates based on ID
            final_dataset = final_dataset.drop_duplicates(subset=['id'], keep='last')
        else:
            # If no previous file exists, initialize with the baseline dataset
            original_rows = baseline_df.copy()
            original_rows['is_synthetic'] = False
            original_rows['source_method'] = 'ORIGINAL'
            original_rows['similarity_score'] = 1.0
            final_dataset = pd.concat([original_rows, new_data_df], ignore_index=True)

        # Write to a sibling temp file and swap it in, so an interruption
        # mid-write cannot leave a truncated OUTPUT_CSV behind.
        tmp_path = f"{OUTPUT_CSV}.tmp"
        final_dataset.to_csv(tmp_path, index=False, encoding='utf-8')
        os.replace(tmp_path, OUTPUT_CSV)

        print(f"Success! Saved {len(new_data_df)} new samples.")
        print(f"Total dataset volume: {len(final_dataset)} records.")
    except Exception as e:
        print(f"CRITICAL: Failed to save data. Error: {e}")


def run_mass_augmentation():
    """Run every configured augmentation method over the input dataset.

    Reads INPUT_CSV, and for each row and each entry of METHODS_TO_TEST asks
    the API at API_URL to augment the text and to score it against the
    original. Accepted samples are collected in memory and written to
    OUTPUT_CSV at the end.

    The collected samples are saved even if the run is interrupted with
    Ctrl+C or aborted by an unexpected exception; in the latter case the
    exception is re-raised after the save attempt.

    Returns:
        None. Progress and errors are reported on stdout.
    """
    print(f"Starting experiment for {len(METHODS_TO_TEST)} variants. S-BERT Threshold: {SBERT_THRESH*100}%")

    # Check if input file exists
    if not os.path.exists(INPUT_CSV):
        print(f"CRITICAL ERROR: Input file {INPUT_CSV} not found.")
        return

    df = pd.read_csv(INPUT_CSV)

    # Fail fast on a wrong schema instead of hitting a KeyError after the
    # first API calls have already been made.
    missing_columns = [col for col in ("id", "text", "label") if col not in df.columns]
    if missing_columns:
        print(f"CRITICAL ERROR: Input file {INPUT_CSV} is missing columns: {', '.join(missing_columns)}")
        return

    augmented_pool = []
    total_rows = len(df)

    try:
        for position, (_, row) in enumerate(df.iterrows(), start=1):
            original_text = row['text']

            # Empty CSV cells are read as NaN (a float), which cannot be
            # sliced or meaningfully augmented.
            if not isinstance(original_text, str):
                print(f"\n[{position}/{total_rows}] Skipping row without text.")
                continue

            print(f"\n[{position}/{total_rows}] Processing: {original_text[:40]}...")

            for method, pivot in METHODS_TO_TEST:
                record = _generate_sample(row['id'], original_text, row['label'], method, pivot)
                if record is not None:
                    augmented_pool.append(record)

                # Safe time delay between API calls to prevent throttling
                time.sleep(1.5 if method == "LLM" else 0.3)

    except KeyboardInterrupt:
        print("\n\nMANUALLY INTERRUPTED! Stopping execution, initiating emergency data save...")
    finally:
        # === DATA PERSISTENCE BLOCK ===
        # Runs on normal completion, on Ctrl+C, and before any unexpected
        # exception propagates, so collected samples are never discarded.
        _persist_augmented_pool(df, augmented_pool)

# Script entry point: run the augmentation experiment only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    run_mass_augmentation()