| import torch |
| from transformers import ( |
| AutoModelForCausalLM, |
| AutoTokenizer, |
| BitsAndBytesConfig |
| ) |
| from langdetect import detect |
| import pandas as pd |
| import numpy as np |
| from tqdm import tqdm |
| from pathlib import Path |
| import logging |
| import gc |
| from typing import List |
| import json |
| from datetime import datetime, timedelta |
| import time |
| import sys |
| from sklearn.feature_extraction.text import TfidfVectorizer |
| from sklearn.linear_model import LogisticRegression |
| import joblib |
|
|
| |
# --- Logging setup: one timestamped log file per run, mirrored to stdout ---
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)

# Run identifier, e.g. "20240131_235959".
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"
log_file = log_dir / f"generation_{timestamp}.log"

# Send every record to both the console and the run's log file.
_handlers = [
    logging.StreamHandler(sys.stdout),
    logging.FileHandler(log_file),
]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(message)s',
    handlers=_handlers,
)

logger = logging.getLogger(__name__)
logger.info(f"Starting new run. Log file: {log_file}")
|
|
def log_separator(message: str = ""):
    """Print a separator line with optional message"""
    bar = "=" * 40
    if message:
        # Labeled separator: 40 '='s on each side of the message.
        logger.info(f"\n{bar} {message} {bar}")
    else:
        # Unlabeled separator: a plain 100-character rule.
        logger.info("\n" + "=" * 100)
|
|
class FastThreatValidator:
    """Fast threat validation using logistic regression.

    Loads a cached TF-IDF + LogisticRegression bundle from *model_path* when
    it exists; otherwise trains one from the project's train split and caches
    it to disk for subsequent runs.
    """

    def __init__(self, model_path: str = "weights/threat_validator.joblib"):
        """Load the cached validator, or train and persist a new one.

        Args:
            model_path: Location of the joblib-serialized dict holding the
                fitted 'vectorizer' and 'model'.
        """
        self.model_path = model_path
        if Path(model_path).exists():
            logger.info("Loading fast threat validator...")
            model_data = joblib.load(model_path)
            self.vectorizer = model_data['vectorizer']
            self.model = model_data['model']
            logger.info("✓ Fast validator loaded")
        else:
            logger.info("Training fast threat validator...")
            self._train_validator()
            logger.info("✓ Fast validator trained and saved")

    def _train_validator(self):
        """Train a simple logistic regression model for threat detection."""
        train_df = pd.read_csv("dataset/split/train.csv")

        # NaN comments become empty strings so the vectorizer accepts them.
        X = train_df['comment_text'].fillna('')
        y = train_df['threat']

        self.vectorizer = TfidfVectorizer(
            max_features=10000,
            ngram_range=(1, 2),
            strip_accents='unicode',
            min_df=2
        )
        X_vec = self.vectorizer.fit_transform(X)

        # class_weight='balanced' compensates for threats being a small
        # minority class in the corpus.
        self.model = LogisticRegression(
            C=1.0,
            class_weight='balanced',
            max_iter=200,
            n_jobs=-1
        )
        self.model.fit(X_vec, y)

        # BUGFIX: create the target directory first — joblib.dump raises
        # FileNotFoundError when "weights/" does not exist yet.
        Path(self.model_path).parent.mkdir(parents=True, exist_ok=True)
        joblib.dump({
            'vectorizer': self.vectorizer,
            'model': self.model
        }, self.model_path)

    def validate(self, texts: List[str], threshold: float = 0.6) -> List[bool]:
        """Score *texts* and return one boolean verdict per text.

        Args:
            texts: Raw comment strings to score.
            threshold: Minimum positive-class probability to count as a threat.

        Returns:
            List of booleans, True where P(threat) >= threshold.
        """
        X = self.vectorizer.transform(texts)

        # Column 1 of predict_proba is the positive ("threat") class.
        probs = self.model.predict_proba(X)[:, 1]

        # BUGFIX: convert to a plain list so the return value matches the
        # annotated List[bool] (previously a numpy bool array leaked out).
        return (probs >= threshold).tolist()
|
|
class ThreatAugmenter:
    """Generates synthetic English threat comments with Mistral-7B, seeded by
    real English threat samples from the multilingual toxic dataset, and saves
    the accepted generations to a timestamped CSV."""

    def __init__(self, seed_samples_path: str = "dataset/processed/MULTILINGUAL_TOXIC_DATASET_360K_7LANG_FINAL.csv"):
        """Load the 4-bit quantized LLM, the fast validator, and the seed data.

        Args:
            seed_samples_path: CSV expected to contain at least 'lang',
                'threat' and 'comment_text' columns; only rows with
                lang == 'en' and threat == 1 are used as generation seeds.
        """
        log_separator("INITIALIZATION")

        # Module-level run log file; generation details are appended as JSONL.
        self.log_file = log_file

        # In-memory buffer of generation log entries, flushed to disk every
        # `buffer_size` entries to limit small writes.
        self.generation_buffer = []
        self.buffer_size = 100

        self.num_gpus = torch.cuda.device_count()
        if self.num_gpus > 0:
            # TF32 matmuls speed up Ampere+ GPUs at negligible precision cost.
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            logger.info(f"Found {self.num_gpus} GPUs:")
            for i in range(self.num_gpus):
                mem = torch.cuda.get_device_properties(i).total_memory / 1024**3
                logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)} ({mem:.1f}GB)")

        log_separator("LOADING MODELS")
        logger.info("Loading Mistral-7B...")

        # 4-bit NF4 quantization with double quantization; compute in fp16.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True
        )

        # NOTE(review): max_memory hard-codes exactly two GPUs (ids 0 and 1);
        # confirm this matches the deployment topology before running elsewhere.
        self.llm = AutoModelForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.3",
            device_map="balanced",
            torch_dtype=torch.float16,
            quantization_config=quantization_config,
            max_memory={0: "22GB", 1: "22GB"}
        )

        # Left padding so batched generation continues from the prompt end.
        self.llm_tokenizer = AutoTokenizer.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.3",
            padding_side="left",
            use_fast=True
        )
        # Mistral ships no pad token; reuse EOS for padding.
        self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
        logger.info("✓ Mistral-7B loaded")

        self.validator = FastThreatValidator()
        logger.info("✓ Fast validator initialized")

        log_separator("LOADING DATA")
        logger.info(f"Loading samples from {seed_samples_path}")
        self.seed_df = pd.read_csv(seed_samples_path)
        # Seed pool: English rows labeled as threats.
        self.en_threat_samples = self.seed_df[
            (self.seed_df['lang'] == 'en') &
            (self.seed_df['threat'] == 1)
        ]
        logger.info(f"Found {len(self.en_threat_samples)} EN threat samples")

        # Batching knobs. NOTE(review): prefetch_factor and num_workers are
        # assigned here but never referenced elsewhere in this file.
        self.max_batch_size = 48
        self.prefetch_factor = 4
        self.num_workers = 8

        # Free any allocation churn from model loading before generation.
        torch.cuda.empty_cache()
        gc.collect()
        log_separator()

    def generate_prompt(self, seed_text: str) -> str:
        """Generate a prompt for realistic threatening content"""
        # Mistral-instruct chat format: <s>[INST] ... [/INST]; the model's
        # reply follows the closing [/INST] tag.
        return f"""<s>[INST] Generate a threatening social media comment that is realistic and intimidating.

Reference example (generate something different): {seed_text}

Requirements:
- Make it genuinely threatening
- Use authentic social media language
- Keep it under 50 words
- Must be different from example

Generate ONLY the comment: [/INST]"""

    def flush_buffer(self):
        """Flush the generation buffer to disk"""
        if self.generation_buffer:
            try:
                # Append one JSON object per line (JSONL) to the run log.
                with open(self.log_file, 'a', encoding='utf-8') as f:
                    for entry in self.generation_buffer:
                        f.write(json.dumps(entry, ensure_ascii=False) + '\n')
                # Cleared only on success, so a failed flush retries later.
                self.generation_buffer = []
            except Exception as e:
                logger.error(f"Failed to flush buffer: {str(e)}")

    def log_generation(self, seed_text: str, prompt: str, generated_text: str, is_valid: bool):
        """Buffer log generation details"""
        log_entry = {
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "seed_text": seed_text,
            "prompt": prompt,
            "generated_text": generated_text,
            "is_valid": is_valid
        }

        self.generation_buffer.append(log_entry)

        # Write out once the buffer reaches capacity.
        if len(self.generation_buffer) >= self.buffer_size:
            self.flush_buffer()

    def generate_samples(self, prompts: List[str], seed_texts: List[str]) -> List[str]:
        """Run one batched LLM generation pass and return the cleaned replies.

        Args:
            prompts: Fully formatted [INST] prompts, one per seed.
            seed_texts: The seed comments (currently unused in the body).

        Returns:
            Cleaned response strings that passed the lexical filters;
            empty list on any generation error.
        """
        try:
            with torch.amp.autocast('cuda', dtype=torch.float16):
                inputs = self.llm_tokenizer(prompts, return_tensors="pt", padding=True,
                                           truncation=True, max_length=256).to(self.llm.device)

                outputs = self.llm.generate(
                    **inputs,
                    max_new_tokens=32,
                    temperature=0.95,
                    do_sample=True,
                    top_p=0.92,
                    top_k=50,
                    num_return_sequences=1,
                    repetition_penalty=1.15,
                    pad_token_id=self.llm_tokenizer.pad_token_id,
                    eos_token_id=self.llm_tokenizer.eos_token_id
                )

            # Keep special tokens so the [/INST] and </s> markers below can be
            # used to slice out just the model's reply.
            texts = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=False)
            cleaned_texts = []
            valid_count = 0

            for idx, text in enumerate(texts):
                # Only accept fully terminated replies (both markers present).
                if "[/INST]" in text and "</s>" in text:
                    response = text.split("[/INST]")[1].split("</s>")[0].strip()
                    response = response.strip().strip('"').strip("'")

                    # Lexical filter: 3-50 words and no prompt-echo artifacts.
                    word_count = len(response.split())
                    if (word_count >= 3 and word_count <= 50 and
                        not any(x in response.lower() for x in [
                            "generate", "requirements:", "reference",
                            "[inst]", "example"
                        ])):
                        cleaned_texts.append(response)
                        valid_count += 1

            if valid_count > 0:
                logger.info(f"\nBatch Success: {valid_count}/{len(texts)} ({valid_count/len(texts)*100:.1f}%)")

            return cleaned_texts

        except Exception as e:
            # Best-effort: the caller treats an empty list as "retry batch".
            logger.error(f"Generation error: {str(e)}")
            return []

    def validate_toxicity(self, texts: List[str]) -> torch.Tensor:
        """Validate texts using fast logistic regression"""
        # Empty input short-circuits to an empty bool tensor.
        if not texts:
            return torch.zeros(0, dtype=torch.bool)

        validation_mask = self.validator.validate(texts)

        # NOTE(review): this mask is computed but augment_dataset never calls
        # this method — generated samples are accepted unvalidated; confirm.
        return torch.tensor(validation_mask, dtype=torch.bool, device=self.llm.device)

    def validate_language(self, texts: List[str]) -> List[bool]:
        """Simple language validation"""
        # NOTE(review): langdetect.detect raises on empty/degenerate strings;
        # callers must pass non-empty text.
        return [detect(text) == 'en' for text in texts]

    def augment_dataset(self, target_samples: int = 500, batch_size: int = 32):
        """Main augmentation loop with progress bar and CSV saving"""
        try:
            start_time = time.time()
            logger.info(f"Starting generation: target={target_samples}, batch_size={batch_size}")
            generated_samples = []
            stats = {
                "total_attempts": 0,
                "valid_samples": 0,
                "batch_times": []
            }

            output_dir = Path("dataset/augmented")
            output_dir.mkdir(parents=True, exist_ok=True)

            # Local timestamp (shadows the module-level one) for the output file.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_file = output_dir / f"threat_augmented_{timestamp}.csv"

            pbar = tqdm(total=target_samples,
                       desc="Generating samples",
                       unit="samples",
                       ncols=100,
                       bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]')

            # Loop until enough samples accumulate; batches that yield nothing
            # are skipped without counting toward stats.
            while len(generated_samples) < target_samples:
                batch_start = time.time()

                # Random seeds each batch (sampled with default no-replacement).
                seed_texts = self.en_threat_samples['comment_text'].sample(batch_size).tolist()
                prompts = [self.generate_prompt(text) for text in seed_texts]
                new_samples = self.generate_samples(prompts, seed_texts)

                if not new_samples:
                    continue

                batch_time = time.time() - batch_start
                stats["batch_times"].append(batch_time)
                stats["total_attempts"] += len(new_samples)
                prev_len = len(generated_samples)
                generated_samples.extend(new_samples)
                stats["valid_samples"] = len(generated_samples)

                pbar.update(len(generated_samples) - prev_len)

                # Refresh success-rate/latency postfix every 10th recorded batch,
                # averaging latency over (at most) the last 20 batches.
                if len(stats["batch_times"]) % 10 == 0:
                    success_rate = (stats["valid_samples"] / stats["total_attempts"]) * 100
                    avg_batch_time = sum(stats["batch_times"][-20:]) / min(len(stats["batch_times"]), 20)
                    pbar.set_postfix({
                        'Success Rate': f'{success_rate:.1f}%',
                        'Batch Time': f'{avg_batch_time:.2f}s'
                    })

                # Periodic GPU/heap cleanup during long runs.
                if len(generated_samples) % (batch_size * 5) == 0:
                    torch.cuda.empty_cache()
                    gc.collect()

            pbar.close()

            # Truncate to exactly target_samples (the last batch may overshoot).
            df = pd.DataFrame({
                'text': generated_samples[:target_samples],
                'label': 1,
                'source': 'augmented',
                'timestamp': timestamp
            })

            df.to_csv(output_file, index=False)
            logger.info(f"\nSaved {len(df)} samples to {output_file}")

            total_time = str(timedelta(seconds=int(time.time() - start_time)))
            logger.info(f"Generation complete: {len(generated_samples)} samples generated in {total_time}")

            return df

        except Exception as e:
            # Log and re-raise so the caller sees the failure.
            logger.error(f"Generation failed: {str(e)}")
            raise
|
|
if __name__ == "__main__":
    # Release cached GPU memory and collect garbage before the 7B model load.
    torch.cuda.empty_cache()
    gc.collect()

    augmenter = ThreatAugmenter()
    augmented_df = augmenter.augment_dataset(target_samples=500)