# Page-scrape residue kept as a comment: Spaces: Sleeping | File size: 5,903 Bytes | 69a2c97
import requests
import pandas as pd
import time
import os
# --- File locations and service endpoint ---
INPUT_CSV = "dataset_test.csv"
OUTPUT_CSV = "augmented_dataset2.csv"
API_URL = "http://127.0.0.1:8000"

# --- Hyperparameters (can be modified for subsequent tests) ---
EDA_P = 0.15
SBERT_THRESH = 0.70

# Every configuration to run, as (Method, Pivot_Language) pairs:
#   EDA - lexical perturbations (pivot ignored)
#   BT  - back-translation, once per pivot language (English, German, Czech)
#   LLM - generative LLM paraphrasing (pivot ignored)
METHODS_TO_TEST = [
    ("EDA", "none"),
    *[("BT", pivot_lang) for pivot_lang in ("en", "de", "cs")],
    ("LLM", "none"),
]
# (connect, read) timeouts in seconds applied to every API call. Without a
# timeout `requests` waits indefinitely on a stalled server. The read timeout
# is generous because a single augmentation call may take a while.
_REQUEST_TIMEOUT = (5, 120)


def _request_augmentation(original_text, method, pivot, method_name):
    """Ask the ``/augment`` endpoint for one augmented variant of a text.

    Args:
        original_text: The source sentence to augment.
        method: Augmentation method sent to the API ("EDA", "BT" or "LLM").
        pivot: Pivot language sent as ``pivot_lang``.
        method_name: Label used in console messages.

    Returns:
        The augmented text, or ``None`` when the API answered with a non-200
        status or returned an empty / error-marker result.

    Raises:
        requests.exceptions.RequestException: On network failure or timeout.
        ValueError: When the response body is not a JSON object.
    """
    payload = {
        "text": original_text,
        "method": method,
        "pivot_lang": pivot,
        "eda_p": EDA_P,
    }
    aug_response = requests.post(
        f"{API_URL}/augment", json=payload, timeout=_REQUEST_TIMEOUT
    )

    if aug_response.status_code != 200:
        error_msg = aug_response.text
        print(f" API Rejected ({aug_response.status_code}) for {method_name}")
        # Defensive mechanism against API rate limits (e.g., Groq Cloud).
        # Check the real HTTP status first; the body is still searched because
        # the local API may wrap an upstream 429/502 in a different status.
        rate_limited = (
            aug_response.status_code in (429, 502)
            or "429" in error_msg
            or "502" in error_msg
            or "rate limit" in error_msg.lower()
        )
        if rate_limited:
            print(" API Rate limit reached. Sleeping for 60 seconds...")
            time.sleep(60)
        return None

    aug_data = aug_response.json()
    if not isinstance(aug_data, dict):
        raise ValueError("/augment did not return a JSON object")

    augmented_text = aug_data.get("augmented")
    # "BŁĄD" / "ERROR" are the error markers the API places in the payload.
    if not augmented_text or augmented_text in ["BŁĄD", "ERROR"]:
        return None
    return augmented_text


def _request_similarity(original_text, augmented_text, method_name):
    """Ask the ``/filter`` endpoint to compare an augmented text to its source.

    Args:
        original_text: The source sentence.
        augmented_text: The candidate produced by the augmentation step.
        method_name: Label used in console messages.

    Returns:
        A ``(similarity, passed)`` tuple, or ``None`` when the endpoint
        answered with a non-200 status (so a server error is not mistaken
        for a semantic rejection).

    Raises:
        requests.exceptions.RequestException: On network failure or timeout.
        ValueError: When the body is not a JSON object or the similarity
            value is not numeric.
    """
    filter_payload = {
        "original": original_text,
        "augmented": augmented_text,
        "threshold": SBERT_THRESH,
    }
    filt_response = requests.post(
        f"{API_URL}/filter", json=filter_payload, timeout=_REQUEST_TIMEOUT
    )
    if filt_response.status_code != 200:
        print(f" Filter API Rejected ({filt_response.status_code}) for {method_name}")
        return None

    filt_res = filt_response.json()
    if not isinstance(filt_res, dict):
        raise ValueError("/filter did not return a JSON object")

    # A missing or null similarity is treated as 0 so formatting cannot fail.
    sim_score = float(filt_res.get("similarity") or 0)
    passed = bool(filt_res.get("passed", False))
    return sim_score, passed


def _persist_augmented_pool(df, augmented_pool):
    """Write the accepted samples to OUTPUT_CSV.

    If OUTPUT_CSV already exists the new samples are appended to it and
    duplicate ids are collapsed (the newest row wins). Otherwise the file is
    created from the baseline rows of ``df`` followed by the new samples.

    Args:
        df: The baseline dataset loaded from INPUT_CSV.
        augmented_pool: List of dicts, one per accepted synthetic sample.
    """
    if not augmented_pool:
        print("\nNo new accepted samples to serialize.")
        return

    print("\nWriting generated samples to disk...")
    new_data_df = pd.DataFrame(augmented_pool)
    try:
        if os.path.exists(OUTPUT_CSV):
            existing_df = pd.read_csv(OUTPUT_CSV)
            final_dataset = pd.concat([existing_df, new_data_df], ignore_index=True)
            # Remove any accidental duplicates based on ID
            final_dataset = final_dataset.drop_duplicates(subset=['id'], keep='last')
        else:
            # No previous file: initialize with the baseline dataset
            baseline_df = df.copy()
            baseline_df['is_synthetic'] = False
            baseline_df['source_method'] = 'ORIGINAL'
            baseline_df['similarity_score'] = 1.0
            final_dataset = pd.concat([baseline_df, new_data_df], ignore_index=True)

        final_dataset.to_csv(OUTPUT_CSV, index=False, encoding='utf-8')
        print(f"Success! Saved {len(new_data_df)} new samples.")
        print(f"Total dataset volume: {len(final_dataset)} records.")
    except Exception as e:
        # Last-resort boundary: report the failure instead of crashing.
        print(f"CRITICAL: Failed to save data. Error: {e}")


def run_mass_augmentation():
    """Augment every row of INPUT_CSV with each configured method and save the results.

    For each row and each (method, pivot) pair in METHODS_TO_TEST the text is
    sent to the ``/augment`` endpoint, the result is scored by the ``/filter``
    endpoint against SBERT_THRESH, and accepted samples are collected in
    memory. The collected samples are written to OUTPUT_CSV when the loop
    finishes, when the user presses Ctrl+C, or when an unexpected error
    aborts the loop (the error is re-raised after the save attempt).

    Returns:
        None. Progress and errors are reported on stdout.
    """
    print(f"Starting experiment for {len(METHODS_TO_TEST)} variants. S-BERT Threshold: {SBERT_THRESH*100}%")

    # Check if input file exists
    if not os.path.exists(INPUT_CSV):
        print(f"CRITICAL ERROR: Input file {INPUT_CSV} not found.")
        return

    df = pd.read_csv(INPUT_CSV)

    # Fail early with a clear message instead of a KeyError deep in the loop.
    missing_columns = {"id", "text", "label"} - set(df.columns)
    if missing_columns:
        print(f"CRITICAL ERROR: Input file {INPUT_CSV} is missing columns: {sorted(missing_columns)}")
        return

    augmented_pool = []
    total_rows = len(df)

    try:
        # enumerate() gives a reliable 1-based counter even when the frame's
        # index is not a 0-based integer range.
        for position, (_, row) in enumerate(df.iterrows(), start=1):
            original_text = row['text']
            label = row['label']

            # Empty CSV cells load as NaN (a float); slicing them would crash.
            if not isinstance(original_text, str) or not original_text.strip():
                print(f"\n[{position}/{total_rows}] Skipping row without usable text.")
                continue

            print(f"\n[{position}/{total_rows}] Processing: {original_text[:40]}...")

            for method, pivot in METHODS_TO_TEST:
                method_name = f"{method}_{pivot}" if method == "BT" else method

                try:
                    # 1. GENERATION PHASE
                    augmented_text = _request_augmentation(
                        original_text, method, pivot, method_name
                    )

                    # 2. S-BERT SEMANTIC FILTRATION PHASE
                    verdict = None
                    if augmented_text is not None:
                        verdict = _request_similarity(
                            original_text, augmented_text, method_name
                        )

                    # 3. MEMORY BUFFER ALLOCATION
                    if verdict is not None:
                        sim_score, passed = verdict
                        if passed:
                            print(f" {method_name.ljust(6)} ACCEPTED (Sim: {sim_score:.3f})")
                            augmented_pool.append({
                                "id": f"{row['id']}_aug_{method_name}",
                                "text": augmented_text,
                                "label": label,
                                "is_synthetic": True,
                                "source_method": method_name,
                                "similarity_score": sim_score,
                            })
                        else:
                            print(f" {method_name.ljust(6)} REJECTED (Sim: {sim_score:.3f})")

                except requests.exceptions.RequestException as e:
                    print(f" Network connection error for {method_name}: {e}")
                except ValueError as e:
                    # Non-JSON or malformed body; skip this variant only.
                    print(f" Malformed API response for {method_name}: {e}")

                # Safe time delay between API calls to prevent throttling.
                # Runs after every attempt, including rejected ones.
                time.sleep(1.5 if method == "LLM" else 0.3)

    except KeyboardInterrupt:
        print("\n\nMANUALLY INTERRUPTED! Stopping execution, initiating emergency data save...")
    finally:
        # === DATA PERSISTENCE BLOCK ===
        # In `finally` so accepted samples survive unexpected errors too.
        _persist_augmented_pool(df, augmented_pool)
# Script entry point: run the augmentation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    run_mass_augmentation()