# Data_augmentation / backend / run_experiments.py
# Author: Jacek Dusza
# Commit 69a2c97: "Initial commit: NLP Pipeline backend and React frontend"
import requests
import pandas as pd
import time
import os
# --- I/O locations -----------------------------------------------------------
# Source rows; run_mass_augmentation reads its `id`, `text` and `label` columns.
INPUT_CSV: str = "dataset_test.csv"
# Accumulated (original + synthetic) dataset; appended to on every run.
OUTPUT_CSV: str = "augmented_dataset2.csv"
# Base URL of the augmentation backend exposing /augment and /filter.
API_URL: str = "http://127.0.0.1:8000"

# --- Hyperparameters (adjust between experiment runs) ------------------------
# Forwarded to /augment as "eda_p".
EDA_P: float = 0.15
# Forwarded to /filter as "threshold" (minimum S-BERT similarity to keep a sample).
SBERT_THRESH: float = 0.70

# --- Augmentation variants to evaluate: (method, pivot_language) -------------
METHODS_TO_TEST = [
    ("EDA", "none"),  # lexical perturbations; pivot is ignored
    ("BT", "en"),  # back-translation through English
    ("BT", "de"),  # back-translation through German
    ("BT", "cs"),  # back-translation through Czech
    ("LLM", "none"),  # generative LLM paraphrasing; pivot is ignored
]
def _is_rate_limited(status_code: int, body: str) -> bool:
    """Return True when a failed /augment response looks like upstream throttling.

    The HTTP status is checked first. The body heuristics are kept because the
    backend may relay an upstream provider error (e.g. Groq Cloud) inside a
    generic error response.
    """
    return (
        status_code in (429, 502)
        or "429" in body
        or "502" in body
        or "rate limit" in body.lower()
    )


def _augment_and_filter(original_text, method, pivot, method_name, request_timeout):
    """Generate one synthetic variant and score it against the original.

    Args:
        original_text: Source sentence sent to the backend.
        method: Augmentation method name ("EDA", "BT", "LLM").
        pivot: Pivot language for back-translation ("none" when unused).
        method_name: Display/ID label for this variant (e.g. "BT_de").
        request_timeout: Seconds to wait for each HTTP call.

    Returns:
        ``(augmented_text, similarity)`` if the variant passed the S-BERT
        filter, otherwise ``None``.

    Raises:
        requests.exceptions.RequestException: connection errors and timeouts.
        ValueError: a 200 response whose body is not valid JSON.
    """
    # 1. GENERATION PHASE
    payload = {
        "text": original_text,
        "method": method,
        "pivot_lang": pivot,
        "eda_p": EDA_P,
    }
    aug_response = requests.post(f"{API_URL}/augment", json=payload, timeout=request_timeout)
    if aug_response.status_code != 200:
        error_msg = aug_response.text
        print(f" API Rejected ({aug_response.status_code}) for {method_name}")
        # Defensive mechanism against API rate limits (e.g., Groq Cloud)
        if _is_rate_limited(aug_response.status_code, error_msg):
            print(" API Rate limit reached. Sleeping for 60 seconds...")
            time.sleep(60)
        return None

    augmented_text = aug_response.json().get("augmented")
    # Presumably the backend's in-band failure markers ("BŁĄD" is Polish for
    # "ERROR") -- verify against the /augment implementation.
    if not augmented_text or augmented_text in ("BŁĄD", "ERROR"):
        return None

    # 2. S-BERT SEMANTIC FILTRATION PHASE
    filter_payload = {
        "original": original_text,
        "augmented": augmented_text,
        "threshold": SBERT_THRESH,
    }
    filt_response = requests.post(f"{API_URL}/filter", json=filter_payload, timeout=request_timeout)
    if filt_response.status_code != 200:
        # Without this check an error response would be parsed as a silent "REJECTED".
        print(f" Filter Rejected ({filt_response.status_code}) for {method_name}")
        return None

    filt_res = filt_response.json()
    # `or 0.0` also covers an explicit JSON null, which would break the :.3f format.
    sim_score = float(filt_res.get("similarity") or 0.0)
    if filt_res.get("passed", False):
        print(f" {method_name.ljust(6)} ACCEPTED (Sim: {sim_score:.3f})")
        return augmented_text, sim_score
    print(f" {method_name.ljust(6)} REJECTED (Sim: {sim_score:.3f})")
    return None


def _persist_results(baseline_df: pd.DataFrame, augmented_pool: list) -> None:
    """Merge accepted synthetic samples into OUTPUT_CSV.

    If OUTPUT_CSV already exists, the new samples are appended to it and
    de-duplicated on ``id``; the newest row wins. Otherwise the file is
    created from ``baseline_df`` (tagged as ORIGINAL) plus the new samples.
    If writing fails, the new samples are dumped to a side file so the API
    work already done is not lost.
    """
    if not augmented_pool:
        print("\nNo new accepted samples to serialize.")
        return

    print("\nWriting generated samples to disk...")
    new_data_df = pd.DataFrame(augmented_pool)
    try:
        if os.path.exists(OUTPUT_CSV):
            existing_df = pd.read_csv(OUTPUT_CSV)
            final_dataset = pd.concat([existing_df, new_data_df], ignore_index=True)
            # Remove any accidental duplicates based on ID
            final_dataset = final_dataset.drop_duplicates(subset=['id'], keep='last')
        else:
            # No previous file: seed it with the baseline (original) rows.
            seed_df = baseline_df.copy()
            seed_df['is_synthetic'] = False
            seed_df['source_method'] = 'ORIGINAL'
            seed_df['similarity_score'] = 1.0
            final_dataset = pd.concat([seed_df, new_data_df], ignore_index=True)
        final_dataset.to_csv(OUTPUT_CSV, index=False, encoding='utf-8')
        print(f"Success! Saved {len(new_data_df)} new samples.")
        print(f"Total dataset volume: {len(final_dataset)} records.")
    except Exception as e:
        print(f"CRITICAL: Failed to save data. Error: {e}")
        # Last resort (e.g. OUTPUT_CSV locked by another program): keep the new samples.
        backup_path = f"{os.path.splitext(OUTPUT_CSV)[0]}_unsaved_{int(time.time())}.csv"
        try:
            new_data_df.to_csv(backup_path, index=False, encoding='utf-8')
            print(f"Emergency backup of new samples written to {backup_path}")
        except Exception as backup_error:
            print(f"CRITICAL: Emergency backup also failed. Error: {backup_error}")


def run_mass_augmentation(request_timeout: float = 120.0) -> None:
    """Augment every row of INPUT_CSV with each method in METHODS_TO_TEST.

    For each row (reads the ``id``, ``text`` and ``label`` columns) and each
    ``(method, pivot)`` pair, a variant is requested from ``{API_URL}/augment``.
    It is then scored by ``{API_URL}/filter`` against ``SBERT_THRESH``, and kept
    if it passed. Accepted samples are written to OUTPUT_CSV when the loop
    finishes, when it is interrupted with Ctrl+C, or when it aborts on an
    unexpected error. In the last case the error is re-raised after saving.

    Args:
        request_timeout: Seconds to wait for each HTTP call before giving up.
            Without a timeout a stalled backend would hang the run forever.
            The default of 120 s is a guess sized for slow LLM generation;
            tune as needed.
    """
    print(f"Starting experiment for {len(METHODS_TO_TEST)} variants. S-BERT Threshold: {SBERT_THRESH*100}%")
    # Check if input file exists
    if not os.path.exists(INPUT_CSV):
        print(f"CRITICAL ERROR: Input file {INPUT_CSV} not found.")
        return

    df = pd.read_csv(INPUT_CSV)
    total_rows = len(df)
    augmented_pool = []
    try:
        for position, (_, row) in enumerate(df.iterrows(), start=1):
            original_text = row['text']
            # Empty CSV cells load as NaN; slicing/sending a float would crash the run.
            if pd.isna(original_text):
                print(f"\n[{position}/{total_rows}] Skipping row with missing text.")
                continue
            original_text = str(original_text)
            label = row['label']
            print(f"\n[{position}/{total_rows}] Processing: {original_text[:40]}...")

            for method, pivot in METHODS_TO_TEST:
                method_name = f"{method}_{pivot}" if method == "BT" else method
                try:
                    accepted = _augment_and_filter(
                        original_text, method, pivot, method_name, request_timeout
                    )
                except requests.exceptions.RequestException as e:
                    print(f" Network connection error for {method_name}: {e}")
                    accepted = None
                except ValueError as e:
                    print(f" Malformed JSON response for {method_name}: {e}")
                    accepted = None

                # 3. MEMORY BUFFER ALLOCATION
                if accepted is not None:
                    augmented_text, sim_score = accepted
                    augmented_pool.append({
                        "id": f"{row['id']}_aug_{method_name}",
                        "text": augmented_text,
                        "label": label,
                        "is_synthetic": True,
                        "source_method": method_name,
                        "similarity_score": sim_score,
                    })

                # Throttle after every attempt (including failures) to avoid bursts.
                time.sleep(1.5 if method == "LLM" else 0.3)
    except KeyboardInterrupt:
        print("\n\nMANUALLY INTERRUPTED! Stopping execution, initiating emergency data save...")
    finally:
        # === DATA PERSISTENCE BLOCK ===
        # Runs on normal completion, Ctrl+C, and unexpected errors alike.
        _persist_results(df, augmented_pool)
if __name__ == "__main__":
    # Script entry point: run the full augmentation sweep over INPUT_CSV.
    run_mass_augmentation()