| import torch |
| from transformers import ( |
| AutoModelForCausalLM, |
| AutoTokenizer, |
| BitsAndBytesConfig |
| ) |
| import pandas as pd |
| import numpy as np |
| from tqdm import tqdm |
| from pathlib import Path |
| import logging |
| import gc |
| from typing import List, Dict |
| import json |
| from datetime import datetime |
| import time |
| import sys |
| from sklearn.feature_extraction.text import TfidfVectorizer |
| from sklearn.linear_model import LogisticRegression |
| import joblib |
| import random |
|
|
| |
# Each run writes to its own timestamped log file under ./logs.
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = log_dir / f"generation_{timestamp}.log"

# Log to stdout and to this run's file. The file handler is pinned to UTF-8:
# generated comments can contain arbitrary Unicode that the platform default
# encoding (e.g. cp1252 on Windows) may fail to encode, and JSON records are
# also appended to this same file as UTF-8 elsewhere in this module.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(log_file, encoding="utf-8"),
    ],
)

logger = logging.getLogger(__name__)
logger.info(f"Starting new run. Log file: {log_file}")
|
|
class FastToxicValidator:
    """Fast toxicity validation using TF-IDF features and logistic regression.

    One logistic-regression model is kept per toxicity label. On construction a
    previously saved bundle is loaded from ``model_path`` if it exists;
    otherwise the models are trained from ``train_csv`` and saved to
    ``model_path``.
    """

    # Label columns read from the training CSV; one model is trained per label.
    LABELS = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

    def __init__(self, model_path: str = "weights/toxic_validator.joblib",
                 train_csv: str = "dataset/split/train.csv"):
        """Load the validator bundle, training and saving it first if missing.

        Args:
            model_path: joblib file holding a dict with ``'vectorizers'`` and
                ``'models'`` entries, each keyed by label.
            train_csv: CSV with a ``comment_text`` column and one target column
                per label in ``LABELS``; only read when ``model_path`` does not
                exist yet.
        """
        self.model_path = model_path
        self.train_csv = train_csv
        if Path(model_path).exists():
            logger.info("Loading fast toxic validator...")
            # joblib.load unpickles arbitrary objects: only point model_path at
            # bundles produced by this class, never at untrusted files.
            model_data = joblib.load(model_path)
            self.vectorizers = model_data['vectorizers']
            self.models = model_data['models']
            logger.info("✓ Fast validator loaded")
        else:
            logger.info("Training fast toxic validator...")
            self._train_validator()
            logger.info("✓ Fast validator trained and saved")

    def _train_validator(self):
        """Train one logistic-regression model per label and save the bundle.

        Every label uses identical TF-IDF settings on the same texts, so a single
        vectorizer is fitted once and shared. The ``vectorizers`` dict keeps its
        per-label layout, which is what the loader and get_probabilities expect.
        """
        train_df = pd.read_csv(self.train_csv)

        vectorizer = TfidfVectorizer(
            max_features=10000,
            ngram_range=(1, 2),
            strip_accents='unicode',
            min_df=2,
        )
        X = vectorizer.fit_transform(train_df['comment_text'].fillna(''))

        self.vectorizers = {}
        self.models = {}
        for label in self.LABELS:
            # class_weight='balanced' reweights classes inversely to their
            # frequency in this label's column.
            model = LogisticRegression(
                C=1.0,
                class_weight='balanced',
                max_iter=200,
                n_jobs=-1,
            )
            model.fit(X, train_df[label])

            self.vectorizers[label] = vectorizer
            self.models[label] = model

        # Create the target directory (e.g. "weights/") so saving cannot fail
        # after the models have already been trained.
        Path(self.model_path).parent.mkdir(parents=True, exist_ok=True)
        joblib.dump({
            'vectorizers': self.vectorizers,
            'models': self.models,
        }, self.model_path)

    def get_probabilities(self, texts: List[str], label: str) -> np.ndarray:
        """Return the positive-class probability for ``label`` for each text."""
        X = self.vectorizers[label].transform(texts)
        return self.models[label].predict_proba(X)[:, 1]

    def validate(self, texts: List[str], label: str, threshold: float = 0.5) -> np.ndarray:
        """Return a boolean array: True where the ``label`` probability is >= ``threshold``.

        The default threshold is 0.5; callers may pass a lower one to be more
        permissive.
        """
        return self.get_probabilities(texts, label) >= threshold
|
|
class ToxicAugmenter:
    """Generate synthetic comments for single toxicity labels with an LLM.

    Each generated comment is scored by FastToxicValidator and kept only if its
    probability for the requested label reaches a per-label threshold.
    """

    # Label columns emitted for every generated sample (1 for the requested
    # label, 0 for the others).
    LABELS = ('toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate')

    # Description substituted for {type} in the prompt template, per label.
    _PROMPT_TYPES = {
        'toxic': 'toxic and offensive',
        'severe_toxic': 'toxic and hostile',
        'obscene': 'obscene and vulgar',
        'threat': 'threatening and intimidating',
        'insult': 'insulting and demeaning',
        'identity_hate': 'discriminatory and hateful',
    }

    # Minimum validator probability for a generated sample to be accepted.
    _VALIDATION_THRESHOLDS = {
        'toxic': 0.3,
        'severe_toxic': 0.2,
        'obscene': 0.2,
        'threat': 0.2,
        'insult': 0.2,
        'identity_hate': 0.2,
    }
    _DEFAULT_THRESHOLD = 0.2

    # Instruction prompt in Mistral's [INST] format. It starts with a literal
    # "<s>", so the tokenizer is configured not to add its own BOS token.
    _PROMPT_TEMPLATE = (
        "<s>[INST] You are tasked with generating a highly {type} comment. "
        "Your goal is to create content that would definitely be flagged as {type}.\n"
        "\n"
        'Example for reference (generate something different): "{seed_text}"\n'
        "\n"
        "Requirements:\n"
        "- Must be extremely {type} in nature\n"
        "- Use very strong language\n"
        "- Must be different from the example\n"
        "\n"
        "Generate ONLY the {type} comment (no explanations): [/INST]"
    )

    def __init__(self, model_name: str = "mistralai/Mistral-7B-Instruct-v0.3",
                 max_memory_per_gpu: str = "22GB"):
        """Load the 4-bit quantized LLM, its tokenizer and the fast validator.

        Args:
            model_name: Hugging Face model id to load for generation.
            max_memory_per_gpu: memory budget passed via ``max_memory`` for each
                GPU that torch reports as visible.
        """
        logger.info("Initializing ToxicAugmenter...")

        # Generation records are buffered and appended to the log file in chunks
        # of buffer_size by flush_buffer().
        self.generation_buffer = []
        self.buffer_size = 100

        self.num_gpus = torch.cuda.device_count()
        if self.num_gpus > 0:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            logger.info(f"Found {self.num_gpus} GPUs:")
            for i in range(self.num_gpus):
                mem = torch.cuda.get_device_properties(i).total_memory / 1024**3
                logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)} ({mem:.1f}GB)")

        logger.info(f"Loading {model_name}...")

        # load_in_4bit is set explicitly so the bnb_4bit_* options below describe
        # the quantization that is actually requested.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )

        # add_bos_token=False: the prompt template already begins with "<s>".
        # add_eos_token=False: an EOS appended to a generation prompt marks the
        # sequence as finished before the model has produced anything.
        self.llm_tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            padding_side="left",
            use_fast=True,
            model_max_length=512,
            add_eos_token=False,
            add_bos_token=False,
        )
        # Pad with EOS so padded calls have a valid pad id.
        self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token

        # Only budget memory for GPUs that exist; naming a missing device here
        # would make the balanced device map place layers on it.
        max_memory = {i: max_memory_per_gpu for i in range(self.num_gpus)} or None

        self.llm = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="balanced",
            torch_dtype=torch.float16,
            quantization_config=quantization_config,
            max_memory=max_memory,
            use_cache=True,
            pad_token_id=self.llm_tokenizer.pad_token_id,
        )
        logger.info(f"✓ {model_name} loaded")

        self.validator = FastToxicValidator()
        logger.info("✓ Fast validator initialized")

    def generate_prompt(self, seed_text: str, label: str) -> str:
        """Build the instruction prompt asking for a comment of type ``label``.

        Raises:
            KeyError: if ``label`` is not one of the known labels.
        """
        return self._PROMPT_TEMPLATE.format(
            type=self._PROMPT_TYPES[label],
            seed_text=seed_text,
        )

    def flush_buffer(self):
        """Append buffered generation records to ``log_file`` as JSON lines.

        The buffer is cleared only after a successful write. On failure the
        error is logged and the records stay buffered for the next flush (a
        partially completed write may therefore be repeated).
        """
        if not self.generation_buffer:
            return
        try:
            # NOTE(review): log_file is also the logging FileHandler's target, so
            # these JSON lines are interleaved with ordinary log lines.
            with open(log_file, 'a', encoding='utf-8') as f:
                f.writelines(
                    json.dumps(entry, ensure_ascii=False) + '\n'
                    for entry in self.generation_buffer
                )
            self.generation_buffer = []
        except Exception as e:
            logger.exception(f"Failed to flush buffer: {e}")

    def log_generation(self, seed_text: str, prompt: str, generated_text: str,
                       validation_results: Dict[str, bool]):
        """Buffer one generation record, flushing once ``buffer_size`` is reached.

        Validation values are coerced with bool() so numpy booleans serialize
        as JSON true/false.
        """
        self.generation_buffer.append({
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "seed_text": seed_text,
            "prompt": prompt,
            "generated_text": generated_text,
            "validation_results": {k: bool(v) for k, v in validation_results.items()},
        })

        if len(self.generation_buffer) >= self.buffer_size:
            self.flush_buffer()

    def validate_sample(self, text: str, label: str, attempts: int) -> bool:
        """Return True if the validator's ``label`` probability reaches its threshold.

        The check is logged when it passes and on every 5th attempt. Any error
        raised by the validator is logged and treated as a failed validation.
        """
        try:
            prob = self.validator.get_probabilities([text], label)[0]
            threshold = self._VALIDATION_THRESHOLDS.get(label, self._DEFAULT_THRESHOLD)

            passed = bool(prob >= threshold)
            if passed or (attempts % 5 == 0):
                logger.info(f"\nValidation - Label: {label}, Text: {text}")
                logger.info(f"Probability: {prob:.3f}, Threshold: {threshold:.2f}, Passed: {passed}")

            return passed

        except Exception as e:
            logger.error(f"Validation error: {str(e)}")
            return False

    def _generate_completion(self, prompt: str) -> str:
        """Run the LLM on one prompt; return only the new text, stripped of quotes."""
        # NOTE(review): with the tokenizer's default right-side truncation, a very
        # long seed text can push the closing [/INST] past max_length=512.
        inputs = self.llm_tokenizer(prompt, return_tensors="pt", padding=True,
                                    truncation=True, max_length=512).to(self.llm.device)

        with torch.no_grad():
            outputs = self.llm.generate(
                **inputs,
                max_new_tokens=200,
                num_beams=4,
                temperature=1.35,
                do_sample=True,
                top_p=0.99,
                top_k=200,
                num_return_sequences=1,
                repetition_penalty=1.0,
                no_repeat_ngram_size=0,
                early_stopping=True,
                pad_token_id=self.llm_tokenizer.pad_token_id,
                bos_token_id=self.llm_tokenizer.bos_token_id,
                eos_token_id=self.llm_tokenizer.eos_token_id,
                use_cache=True,
            )

        # The returned sequence is prompt + completion. Slice off the prompt tokens
        # instead of splitting the decoded text on "[/INST]": skip_special_tokens=True
        # may remove that marker when the tokenizer registers it as a special token.
        prompt_length = inputs["input_ids"].shape[1]
        text = self.llm_tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return text.strip().strip('"').strip("'")

    def generate_samples(self, target_samples: int, label: str,
                         seed_texts: List[str], total_timeout: int = 300) -> pd.DataFrame:
        """Generate up to ``target_samples`` validated samples for one label.

        Each attempt draws a random seed from ``seed_texts``, generates one
        completion and keeps it if it is at least 10 characters long and passes
        validate_sample. Generation stops when the target is met, after
        ``target_samples * 50`` attempts, or after ``total_timeout`` seconds.

        Returns:
            DataFrame with ``comment_text`` plus one column per label in LABELS
            (1 for ``label``, 0 otherwise), or None if nothing was accepted.
        """
        if len(seed_texts) == 0:
            logger.warning(f"No seed texts provided for {label}; nothing to generate")
            return None

        start_time = time.time()
        generated_samples = []
        attempts = 0
        max_attempts = target_samples * 50

        pbar = tqdm(total=target_samples, desc=f"Generating {label} samples")

        try:
            while len(generated_samples) < target_samples and attempts < max_attempts:
                if time.time() - start_time > total_timeout:
                    logger.warning(f"Generation timed out after {total_timeout} seconds")
                    break

                attempts += 1
                seed_text = random.choice(seed_texts)
                prompt = self.generate_prompt(seed_text, label)

                try:
                    output = self._generate_completion(prompt)

                    # Very short completions are treated as failed generations.
                    if len(output) >= 10:
                        if attempts % 5 == 0:
                            logger.info(f"\nAttempt {attempts}: Generated text: {output}")

                        if self.validate_sample(output, label, attempts):
                            sample_dict = {'comment_text': output}
                            sample_dict.update({name: int(name == label) for name in self.LABELS})
                            generated_samples.append(sample_dict)
                            pbar.update(1)
                            logger.info(f"✓ Valid {label} sample generated ({len(generated_samples)}/{target_samples})")

                except Exception as e:
                    logger.error(f"Generation error on attempt {attempts}: {str(e)}")

                # Periodically release cached GPU memory during long runs.
                if attempts % 200 == 0:
                    torch.cuda.empty_cache()
                    gc.collect()

        finally:
            pbar.close()
            logger.info(f"Generation finished: {len(generated_samples)}/{target_samples} samples in {attempts} attempts")

        if generated_samples:
            return pd.DataFrame(generated_samples)
        return None

    def augment_dataset(self, target_samples: int, label: str, seed_texts: List[str],
                        timeout_minutes: int = 5, max_timeout_seconds: float = 300) -> pd.DataFrame:
        """Generate ``target_samples`` samples for ``label`` in chunks within a time budget.

        Work is split into chunks of at most 32 samples. Each chunk is delegated
        to generate_samples with seeds resampled from ``seed_texts`` and whatever
        remains of the time budget.

        Args:
            target_samples: number of samples wanted.
            label: toxicity label to generate for.
            seed_texts: example comments used to build prompts.
            timeout_minutes: requested time budget in minutes.
            max_timeout_seconds: upper bound on the budget in seconds. The
                default of 300 keeps the previously hard-coded cap, which limits
                ``timeout_minutes`` to 5; pass ``float('inf')`` to honor
                ``timeout_minutes`` as given.

        Returns:
            DataFrame of at most ``target_samples`` rows, or None if nothing was
            generated or an error occurred (the error is logged).
        """
        logger.info(f"\nGenerating {target_samples} samples with label: {label}")

        generated_samples = []
        batch_size = min(32, target_samples)
        start_time = time.time()
        timeout_seconds = min(timeout_minutes * 60, max_timeout_seconds)
        total_generated = 0
        pbar = None

        try:
            pbar = tqdm(
                total=target_samples,
                desc="Generating",
                unit="samples",
                ncols=100,
                bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'
            )

            while total_generated < target_samples:
                elapsed_time = time.time() - start_time
                if elapsed_time > timeout_seconds:
                    logger.warning(f"Time limit reached after {elapsed_time/60:.1f} minutes")
                    break

                remaining = target_samples - total_generated
                current_batch_size = min(batch_size, remaining)
                batch_seeds = np.random.choice(seed_texts, size=current_batch_size)

                batch_start = time.time()
                new_samples = self.generate_samples(
                    target_samples=current_batch_size,
                    label=label,
                    seed_texts=batch_seeds,
                    total_timeout=timeout_seconds - elapsed_time,
                )
                if new_samples is None or new_samples.empty:
                    continue

                new_samples = new_samples.head(remaining)
                num_new = len(new_samples)
                generated_samples.append(new_samples)
                total_generated += num_new
                pbar.update(num_new)

                # Re-measure after the chunk: elapsed_time was taken before it ran,
                # so using it would make the rate and remaining time stale.
                now = time.time()
                elapsed_minutes = (now - start_time) / 60
                rate = total_generated / elapsed_minutes if elapsed_minutes > 0 else 0
                pbar.set_postfix({
                    'rate': f'{rate:.1f}/min',
                    'batch': f'{now - batch_start:.1f}s',
                    'remain': f'{max(0, timeout_seconds - (now - start_time)):.0f}s'
                }, refresh=True)

                if total_generated % (batch_size * 4) == 0:
                    torch.cuda.empty_cache()

            if not generated_samples:
                return None

            final_df = pd.concat(generated_samples, ignore_index=True).head(target_samples)
            logger.info(f"Successfully generated {len(final_df)} samples in {(time.time() - start_time)/60:.1f} minutes")
            return final_df

        except Exception as e:
            logger.exception(f"Generation error: {e}")
            return None
        finally:
            if pbar is not None:
                pbar.close()

            self.flush_buffer()
            torch.cuda.empty_cache()