Spaces:
Sleeping
Sleeping
| import requests | |
| import pandas as pd | |
| import time | |
| import os | |
# --- File locations and service endpoint -----------------------------------
INPUT_CSV = "dataset_test.csv"          # source dataset read by the experiment
OUTPUT_CSV = "augmented_dataset2.csv"   # destination for the augmented dataset
API_URL = "http://127.0.0.1:8000"       # base URL of the augmentation/filter API

# --- Hyperparameters (tweak between experiment runs) -----------------------
EDA_P = 0.15          # sent as "eda_p" in every /augment request
SBERT_THRESH = 0.70   # sent as "threshold" in every /filter request

# Pivot languages used for back-translation: English, German, Czech.
_BT_PIVOT_LANGUAGES = ("en", "de", "cs")

# Full list of (method, pivot_language) configurations to evaluate.
# "EDA" (lexical perturbations) and "LLM" (generative paraphrasing) ignore the
# pivot, so they carry the placeholder "none"; "BT" (back-translation) is
# expanded once per pivot language.
METHODS_TO_TEST = [
    ("EDA", "none"),
    *(("BT", pivot_language) for pivot_language in _BT_PIVOT_LANGUAGES),
    ("LLM", "none"),
]
def _is_rate_limited(response):
    """Return True when *response* looks like an upstream throttling error."""
    # Trust the HTTP status first; the body check is kept because the local
    # API may wrap an upstream 429/502 inside a different status code.
    if response.status_code in (429, 502):
        return True
    error_msg = response.text
    return "429" in error_msg or "502" in error_msg or "rate limit" in error_msg.lower()


def _generate_candidate(original_text, method, pivot, method_name, timeout):
    """Call ``/augment`` and return the augmented text, or None if unusable."""
    payload = {
        "text": original_text,
        "method": method,
        "pivot_lang": pivot,
        "eda_p": EDA_P
    }
    aug_response = requests.post(f"{API_URL}/augment", json=payload, timeout=timeout)

    if aug_response.status_code != 200:
        print(f" API Rejected ({aug_response.status_code}) for {method_name}")
        # Defensive mechanism against API rate limits (e.g., Groq Cloud)
        if _is_rate_limited(aug_response):
            print(" API Rate limit reached. Sleeping for 60 seconds...")
            time.sleep(60)
        return None

    augmented_text = aug_response.json().get("augmented")
    # The API signals a failed generation with an empty value or a sentinel.
    if not augmented_text or augmented_text in ("BŁĄD", "ERROR"):
        return None
    return augmented_text


def _score_candidate(original_text, augmented_text, method_name, timeout):
    """Call ``/filter``; return ``(similarity, passed)`` or None on API error."""
    filter_payload = {
        "original": original_text,
        "augmented": augmented_text,
        "threshold": SBERT_THRESH
    }
    filt_response = requests.post(f"{API_URL}/filter", json=filter_payload, timeout=timeout)

    # Without this check an error payload would be parsed as a normal result
    # and misreported as a semantic rejection with similarity 0.
    if filt_response.status_code != 200:
        print(f" Filter API Rejected ({filt_response.status_code}) for {method_name}")
        return None

    filt_res = filt_response.json()
    return filt_res.get("similarity", 0), filt_res.get("passed", False)


def _augment_one(row_id, original_text, label, method, pivot, timeout):
    """Generate and filter one variant; return its output record or None."""
    method_name = f"{method}_{pivot}" if method == "BT" else method

    try:
        # 1. GENERATION PHASE
        augmented_text = _generate_candidate(original_text, method, pivot, method_name, timeout)
        if augmented_text is None:
            return None

        # 2. S-BERT SEMANTIC FILTRATION PHASE
        verdict = _score_candidate(original_text, augmented_text, method_name, timeout)
        if verdict is None:
            return None
        sim_score, passed = verdict
    except requests.exceptions.RequestException as e:
        print(f" Network connection error for {method_name}: {e}")
        return None
    except ValueError as e:
        # Older ``requests`` releases raise a plain ValueError on invalid JSON.
        print(f" Malformed API response for {method_name}: {e}")
        return None

    # 3. MEMORY BUFFER ALLOCATION
    if not passed:
        print(f" {method_name.ljust(6)} REJECTED (Sim: {sim_score:.3f})")
        return None

    print(f" {method_name.ljust(6)} ACCEPTED (Sim: {sim_score:.3f})")
    return {
        "id": f"{row_id}_aug_{method_name}",
        "text": augmented_text,
        "label": label,
        "is_synthetic": True,
        "source_method": method_name,
        "similarity_score": sim_score
    }


def _persist_augmented_pool(df, augmented_pool):
    """Write accepted samples to OUTPUT_CSV, merging with any existing file."""
    if not augmented_pool:
        print("\nNo new accepted samples to serialize.")
        return

    print("\nWriting generated samples to disk...")
    new_data_df = pd.DataFrame(augmented_pool)

    try:
        # If the file exists, append to it; otherwise, create a new one with baseline data
        if os.path.exists(OUTPUT_CSV):
            existing_df = pd.read_csv(OUTPUT_CSV)
            final_dataset = pd.concat([existing_df, new_data_df], ignore_index=True)
            # Remove any accidental duplicates based on ID (newest sample wins)
            final_dataset = final_dataset.drop_duplicates(subset=['id'], keep='last')
        else:
            # If no previous file exists, initialize with the baseline dataset
            baseline_df = df.copy()
            baseline_df['is_synthetic'] = False
            baseline_df['source_method'] = 'ORIGINAL'
            baseline_df['similarity_score'] = 1.0
            final_dataset = pd.concat([baseline_df, new_data_df], ignore_index=True)

        # Save the final dataset
        final_dataset.to_csv(OUTPUT_CSV, index=False, encoding='utf-8')
        print(f"Success! Saved {len(new_data_df)} new samples.")
        print(f"Total dataset volume: {len(final_dataset)} records.")
    except Exception as e:
        # Top-level boundary: report instead of crashing after a long run.
        print(f"CRITICAL: Failed to save data. Error: {e}")


def run_mass_augmentation(request_timeout=120.0):
    """Augment every row of INPUT_CSV with each configured method and save.

    For each row and each ``(method, pivot)`` pair in METHODS_TO_TEST the text
    is sent to ``{API_URL}/augment``; the result is then scored by
    ``{API_URL}/filter`` and kept only when the API reports ``passed``.
    Accepted samples are written to OUTPUT_CSV (appended and de-duplicated by
    ``id`` if the file exists, otherwise created together with the baseline
    rows).

    Persistence runs in a ``finally`` clause, so samples collected so far are
    saved on Ctrl-C *and* on any unexpected error.

    Args:
        request_timeout: Seconds to wait for each HTTP call, or a
            ``(connect, read)`` tuple as accepted by ``requests``. Prevents a
            stalled server from hanging the experiment indefinitely.

    Returns:
        None. Progress and errors are reported on stdout.
    """
    print(f"Starting experiment for {len(METHODS_TO_TEST)} variants. S-BERT Threshold: {SBERT_THRESH*100}%")

    # Check if input file exists
    if not os.path.exists(INPUT_CSV):
        print(f"CRITICAL ERROR: Input file {INPUT_CSV} not found.")
        return

    df = pd.read_csv(INPUT_CSV)

    # Fail fast: a missing column would otherwise surface as a KeyError deep
    # inside the loop, after API calls have already been spent.
    missing_columns = [name for name in ("id", "text", "label") if name not in df.columns]
    if missing_columns:
        print(f"CRITICAL ERROR: Input file {INPUT_CSV} is missing column(s): {', '.join(missing_columns)}")
        return

    augmented_pool = []
    total_rows = len(df)

    try:
        for position, (_, row) in enumerate(df.iterrows(), start=1):
            original_text = row['text']
            label = row['label']

            # pandas loads empty cells as NaN (a float), which cannot be
            # sliced or sent to the API as text.
            if not isinstance(original_text, str) or not original_text.strip():
                print(f"\n[{position}/{total_rows}] Skipping row with empty or non-text 'text' value.")
                continue

            print(f"\n[{position}/{total_rows}] Processing: {original_text[:40]}...")

            for method, pivot in METHODS_TO_TEST:
                record = _augment_one(row['id'], original_text, label, method, pivot, request_timeout)
                if record is not None:
                    augmented_pool.append(record)

                # Safe time delay between API calls to prevent throttling;
                # applied after every call, including rejected ones.
                time.sleep(1.5 if method == "LLM" else 0.3)
    except KeyboardInterrupt:
        print("\n\nMANUALLY INTERRUPTED! Stopping execution, initiating emergency data save...")
    finally:
        # === DATA PERSISTENCE BLOCK ===
        _persist_augmented_pool(df, augmented_pool)
# Script entry point: start the augmentation experiment only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    run_mass_augmentation()