"""Bulk text-augmentation driver.

For every row of INPUT_CSV, ask the local augmentation API for one variant per
(method, pivot) pair in METHODS_TO_TEST, score each variant with the API's
S-BERT /filter endpoint, and append the accepted variants to OUTPUT_CSV.
"""
import requests
import pandas as pd
import time
import os

INPUT_CSV = "dataset_test.csv"
OUTPUT_CSV = "augmented_dataset2.csv"
API_URL = "http://127.0.0.1:8000"

# Hyperparameters (can be modified for subsequent tests)
EDA_P = 0.15
SBERT_THRESH = 0.70

# Networking: every request gets a timeout (seconds) so a hung server cannot
# stall the whole run; the delays pace calls to avoid upstream throttling.
REQUEST_TIMEOUT = 120
RATE_LIMIT_SLEEP = 60
LLM_DELAY = 1.5
DEFAULT_DELAY = 0.3

# Complete list of configurations: (Method, Pivot_Language)
METHODS_TO_TEST = [
    ("EDA", "none"),  # Lexical perturbations (pivot ignored)
    ("BT", "en"),  # Back-Translation via English
    ("BT", "de"),  # Back-Translation via German
    ("BT", "cs"),  # Back-Translation via Czech
    ("LLM", "none"),  # Generative LLM paraphrasing (pivot ignored)
]

# Columns every input row must provide (read when building payloads/records).
REQUIRED_COLUMNS = ("id", "text", "label")

# Placeholder outputs the API may return instead of a real augmentation
# ("BŁĄD" is Polish for "ERROR").
_FAILED_OUTPUTS = ("BŁĄD", "ERROR")


def _is_rate_limited(response):
    """Return True if a failed /augment response looks like throttling.

    The HTTP status is checked first. The body is also scanned because the
    local API may wrap an upstream provider error (e.g. Groq Cloud) inside a
    different status code -- NOTE(review): confirm against the server.
    """
    if response.status_code in (429, 502):
        return True
    body = response.text
    return "429" in body or "502" in body or "rate limit" in body.lower()


def _process_variant(row, method, pivot, method_name):
    """Generate and filter one augmentation of ``row['text']``.

    Returns the record to append to the augmented pool, or None when the API
    rejected the /augment request, returned an empty/placeholder output, or
    the S-BERT filter rejected the variant.

    Raises:
        requests.exceptions.RequestException: on connection failures of
            either call, timeouts, or a non-2xx /filter reply.
        ValueError: when a response body is not valid JSON.
    """
    original_text = row["text"]

    # 1. GENERATION PHASE
    payload = {
        "text": original_text,
        "method": method,
        "pivot_lang": pivot,
        "eda_p": EDA_P,
    }
    aug_response = requests.post(
        f"{API_URL}/augment", json=payload, timeout=REQUEST_TIMEOUT
    )
    if aug_response.status_code != 200:
        print(f" API Rejected ({aug_response.status_code}) for {method_name}")
        # Defensive mechanism against API rate limits (e.g., Groq Cloud)
        if _is_rate_limited(aug_response):
            print(
                f" API Rate limit reached. Sleeping for {RATE_LIMIT_SLEEP} seconds..."
            )
            time.sleep(RATE_LIMIT_SLEEP)
        return None

    augmented_text = aug_response.json().get("augmented")
    if not augmented_text or augmented_text in _FAILED_OUTPUTS:
        return None

    # 2. S-BERT SEMANTIC FILTRATION PHASE
    filter_payload = {
        "original": original_text,
        "augmented": augmented_text,
        "threshold": SBERT_THRESH,
    }
    filt_response = requests.post(
        f"{API_URL}/filter", json=filter_payload, timeout=REQUEST_TIMEOUT
    )
    # Without this, an error reply would be silently read as "rejected, sim 0".
    filt_response.raise_for_status()
    filt_res = filt_response.json()
    # `or 0.0` also covers an explicit null similarity, which would otherwise
    # crash the :.3f formatting below.
    sim_score = filt_res.get("similarity") or 0.0
    passed = bool(filt_res.get("passed", False))

    if not passed:
        print(f" {method_name.ljust(6)} REJECTED (Sim: {sim_score:.3f})")
        return None

    # 3. MEMORY BUFFER ALLOCATION
    print(f" {method_name.ljust(6)} ACCEPTED (Sim: {sim_score:.3f})")
    return {
        "id": f"{row['id']}_aug_{method_name}",
        "text": augmented_text,
        "label": row["label"],
        "is_synthetic": True,
        "source_method": method_name,
        "similarity_score": sim_score,
    }


def _save_results(df, augmented_pool):
    """Persist accepted samples to OUTPUT_CSV.

    If OUTPUT_CSV already exists, the new samples are appended to it and rows
    with a duplicate ``id`` are dropped (the last copy wins). Otherwise the
    file is seeded with the baseline rows of ``df`` (marked is_synthetic=False,
    source_method='ORIGINAL', similarity_score=1.0) followed by the new
    samples. Save failures are printed rather than raised.
    """
    if not augmented_pool:
        print("\nNo new accepted samples to serialize.")
        return

    print("\nWriting generated samples to disk...")
    new_data_df = pd.DataFrame(augmented_pool)
    try:
        if os.path.exists(OUTPUT_CSV):
            existing_df = pd.read_csv(OUTPUT_CSV)
            final_dataset = pd.concat([existing_df, new_data_df], ignore_index=True)
            # Remove any accidental duplicates based on ID
            final_dataset = final_dataset.drop_duplicates(subset=["id"], keep="last")
        else:
            # If no previous file exists, initialize with the baseline dataset
            baseline_df = df.copy()
            baseline_df["is_synthetic"] = False
            baseline_df["source_method"] = "ORIGINAL"
            baseline_df["similarity_score"] = 1.0
            final_dataset = pd.concat([baseline_df, new_data_df], ignore_index=True)

        final_dataset.to_csv(OUTPUT_CSV, index=False, encoding="utf-8")
        print(f"Success! Saved {len(new_data_df)} new samples.")
        print(f"Total dataset volume: {len(final_dataset)} records.")
    except Exception as e:  # top-level save boundary: report, don't crash cleanup
        print(f"CRITICAL: Failed to save data. Error: {e}")


def run_mass_augmentation():
    """Run every configured augmentation over INPUT_CSV and save the results.

    Ctrl+C stops generation early. In every case -- normal completion,
    interruption, or an unexpected error -- the samples accepted so far are
    written to OUTPUT_CSV; an unexpected error is re-raised after the save.
    """
    print(
        f"Starting experiment for {len(METHODS_TO_TEST)} variants. "
        f"S-BERT Threshold: {SBERT_THRESH*100}%"
    )

    # Check if input file exists
    if not os.path.exists(INPUT_CSV):
        print(f"CRITICAL ERROR: Input file {INPUT_CSV} not found.")
        return

    df = pd.read_csv(INPUT_CSV)
    missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        print(f"CRITICAL ERROR: Input file {INPUT_CSV} lacks columns: {missing}")
        return

    augmented_pool = []
    total = len(df)
    try:
        for position, (_, row) in enumerate(df.iterrows(), start=1):
            original_text = row["text"]
            # A NaN text cannot be sliced or sent as valid JSON; skip it.
            if pd.isna(original_text):
                print(f"\n[{position}/{total}] Skipping row {row['id']}: empty text.")
                continue
            print(f"\n[{position}/{total}] Processing: {str(original_text)[:40]}...")

            for method, pivot in METHODS_TO_TEST:
                method_name = f"{method}_{pivot}" if method == "BT" else method
                try:
                    sample = _process_variant(row, method, pivot, method_name)
                except requests.exceptions.RequestException as e:
                    print(f" Network connection error for {method_name}: {e}")
                    sample = None
                except ValueError as e:
                    print(f" Invalid API response for {method_name}: {e}")
                    sample = None

                if sample is not None:
                    augmented_pool.append(sample)

                # Pace every call -- failed ones included -- to prevent throttling.
                time.sleep(LLM_DELAY if method == "LLM" else DEFAULT_DELAY)
    except KeyboardInterrupt:
        print("\n\nMANUALLY INTERRUPTED! Stopping execution, initiating emergency data save...")
    finally:
        # === DATA PERSISTENCE BLOCK ===
        # Runs even if the loop raised, so accepted samples are never lost.
        _save_results(df, augmented_pool)


if __name__ == "__main__":
    run_mass_augmentation()