| """ |
| Generate + validate augmented data using the DL model itself as filter. |
| |
| Strategy: Generate sentences via LLM, then keep only those our DistilBERT |
| model classifies as the target class (top-3 predictions, prob > 0.1). |
| |
| This produces training data the model can learn from — it reinforces |
| patterns the model already partially recognizes. |
| |
| Usage: |
| python augment_dl_validated.py |
| """ |
|
|
| import asyncio |
| import json |
| import os |
| import re |
| import sys |
| from pathlib import Path |
|
|
| import pandas as pd |
| import torch |
| import torch.nn.functional as F |
| from openai import AsyncOpenAI |
| from transformers import AutoTokenizer |
|
|
| sys.path.insert(0, str(Path(__file__).parent)) |
| from train_redsm5_model import SymptomClassifier |
|
|
| os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
# Plain-language symptom definitions that are embedded verbatim in the LLM prompt.
_PSYCHOMOTOR_DEFINITION = (
    "Observable slowing of physical movement/speech (retardation) "
    "OR restlessness/agitation (pacing, fidgeting, can't sit still). "
    "Must be PHYSICAL/OBSERVABLE, not just feeling tired."
)
_COGNITIVE_ISSUES_DEFINITION = (
    "Difficulty concentrating, brain fog, indecisiveness, memory problems, "
    "losing track of conversations, feeling mentally slow/dull."
)

# Symptom classes to augment, keyed by label name.
#   label_id   -- index of the class in the classifier's output probabilities
#   count      -- number of candidate sentences to request from the LLM
#   definition -- description handed to the LLM when generating candidates
TARGETS = {
    "PSYCHOMOTOR": dict(
        label_id=4,
        count=100,
        definition=_PSYCHOMOTOR_DEFINITION,
    ),
    "COGNITIVE_ISSUES": dict(
        label_id=7,
        count=100,
        definition=_COGNITIVE_ISSUES_DEFINITION,
    ),
}
|
|
|
|
async def generate_sentences(
    client,
    symptom,
    definition,
    count,
    model="gemini-3-flash-preview",
    batch_size=25,
    pause=0.5,
):
    """Ask the LLM for about ``count`` first-person sentences describing ``symptom``.

    The work is split into batches of at most ``batch_size`` sentences; each
    batch asks only for its own share, so the total requested equals ``count``.
    A batch is attempted up to three times; a batch that never yields a
    parseable JSON array contributes nothing (best effort, no exception).

    Args:
        client: Object exposing ``chat.completions.create`` as an awaitable
            (e.g. ``openai.AsyncOpenAI``).
        symptom: Label name; underscores become spaces and it is lower-cased
            for the prompt.
        definition: Description of the symptom inserted into the prompt.
        count: Total number of sentences to request. ``count <= 0`` returns ``[]``.
        model: Model identifier forwarded to the API.
        batch_size: Maximum sentences requested per API call (must be >= 1).
        pause: Seconds to wait between consecutive batches.

    Returns:
        De-duplicated sentences (whitespace-stripped, longer than 20
        characters) in the order they were first received.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    max_attempts = 3
    symptom_phrase = symptom.replace("_", " ").lower()
    # A dict is used as an ordered set: unlike set(), the output order is stable across runs.
    collected = {}

    for start in range(0, count, batch_size):
        wanted = min(batch_size, count - start)
        prompt = (
            f"Generate {wanted} unique Reddit-style first-person sentences where someone describes "
            f"{symptom_phrase} symptoms.\n\n"
            f"Definition: {definition}\n\n"
            f"Use informal language. Each must be a single sentence. Vary intensity and vocabulary.\n"
            f"Return ONLY a JSON array of strings."
        )
        for attempt in range(1, max_attempts + 1):
            try:
                r = await client.chat.completions.create(
                    model=model,
                    messages=[
                        {"role": "system", "content": "Return ONLY valid JSON."},
                        {"role": "user", "content": prompt},
                    ],
                    max_tokens=4096,
                    temperature=0.9,
                )
                content = r.choices[0].message.content or ""
                # Drop markdown code fences the model may wrap around the JSON.
                content = re.sub(r"```json\s*", "", content)
                content = re.sub(r"```\s*$", "", content)
                match = re.search(r"\[.*\]", content, re.DOTALL)
                if match:
                    for item in json.loads(match.group()):
                        if not isinstance(item, str):
                            continue
                        text = item.strip()
                        if len(text) > 20:
                            collected.setdefault(text, None)
                    break
                # No JSON array in the reply: retry immediately.
            except Exception as exc:
                # Best-effort boundary around the remote call and JSON parsing:
                # report the failure instead of hiding it, back off, and retry.
                print(
                    f"generate_sentences: attempt {attempt}/{max_attempts} failed: {exc!r}",
                    file=sys.stderr,
                )
                await asyncio.sleep(1)
        # Pause only between batches; there is nothing to wait for after the last one.
        if start + batch_size < count:
            await asyncio.sleep(pause)

    return list(collected)
|
|
|
|
def validate_with_model(sentences, target_label_id, model, tokenizer, device, top_k=3, min_prob=0.1):
    """Keep sentences where the classifier ranks the target class among its top-K.

    Each sentence is tokenized (truncated to 128 tokens), run through ``model``
    without gradient tracking, and accepted when ``target_label_id`` is one of
    the ``top_k`` most probable classes AND its softmax probability is strictly
    greater than ``min_prob``.

    Args:
        sentences: Iterable of candidate strings.
        target_label_id: Index of the desired class in the model's output.
        model: Callable as ``model(input_ids, attention_mask)`` returning
            logits with the class dimension at ``dim=1``; switched to eval mode.
        tokenizer: Callable returning ``"input_ids"`` and ``"attention_mask"``
            tensors when invoked with ``return_tensors="pt"``.
        device: Torch device the input tensors are moved to.
        top_k: How many top-ranked classes count as a hit.
        min_prob: Exclusive lower bound on the target-class probability.

    Returns:
        List of ``(sentence, target_probability)`` tuples sorted by
        probability, highest first.
    """
    passed = []
    model.eval()
    with torch.no_grad():
        for sent in sentences:
            enc = tokenizer(sent, truncation=True, max_length=128, return_tensors="pt")
            input_ids = enc["input_ids"].to(device)
            attention_mask = enc["attention_mask"].to(device)
            logits = model(input_ids, attention_mask)
            probs = F.softmax(logits, dim=1)[0]

            # torch.topk raises if k exceeds the number of classes, so clamp it.
            k = max(0, min(top_k, probs.numel()))
            top_k_ids = torch.topk(probs, k).indices.tolist()
            target_prob = probs[target_label_id].item()

            if target_label_id in top_k_ids and target_prob > min_prob:
                passed.append((sent, target_prob))
    return sorted(passed, key=lambda pair: -pair[1])
|
|
|
|
def _load_api_key(env_path):
    """Return the LLM API key from ``env_path`` or, failing that, the environment ('' if absent)."""
    if env_path.exists():
        for line in env_path.read_text(encoding="utf-8").splitlines():
            if line.startswith("LLM_API_KEY="):
                # Tolerate quoted values such as LLM_API_KEY="abc".
                key = line.split("=", 1)[1].strip().strip("'\"")
                if key:
                    return key
                break
    return os.environ.get("LLM_API_KEY", "").strip()


def _merge_and_save(new_df, aug_dir):
    """Merge ``new_df`` into the existing augmented CSV (if any), save it, and return the result."""
    out_path = aug_dir / "augmented_samples_final.csv"
    aug_dir.mkdir(parents=True, exist_ok=True)
    # Prefer the already-finalised file; fall back to the raw augmented file.
    for fname in ["augmented_samples_final.csv", "augmented_samples.csv"]:
        aug_path = aug_dir / fname
        if aug_path.exists():
            existing = pd.read_csv(aug_path)
            if new_df.empty:
                combined = existing
            else:
                combined = pd.concat([existing, new_df], ignore_index=True)
            # Existing rows win over newly generated rows with identical text.
            combined = combined.drop_duplicates(subset=["clean_text"], keep="first")
            break
    else:
        combined = new_df
    combined.to_csv(out_path, index=False)
    return combined


async def main():
    """Generate candidate sentences per target symptom, filter them with the classifier, and save.

    Steps: resolve the API key, load the classifier and tokenizer, generate
    candidates for every entry of ``TARGETS``, keep those accepted by
    ``validate_with_model``, then merge the survivors into
    ``data/redsm5/augmented_v4/augmented_samples_final.csv`` and print a summary.

    Raises:
        RuntimeError: If no ``LLM_API_KEY`` is found in the ``.env`` file or
            the process environment.
    """
    base_dir = Path(__file__).parent.parent
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

    # Resolve credentials first so a missing key fails before the model is loaded.
    env_path = base_dir.parent / ".env"
    api_key = _load_api_key(env_path)
    if not api_key:
        raise RuntimeError(f"LLM_API_KEY not found in {env_path} or in the environment")

    print("Loading DL model for validation...")
    with open(base_dir / "models" / "baseline_v1" / "redsm5_metadata.json", encoding="utf-8") as f:
        meta = json.load(f)
    # NOTE(review): label_map is read but not used below; num_classes is hard-coded
    # to 11 and TARGETS carries its own label ids -- confirm they agree with it.
    label_map = meta["label_map"]

    dl_model = SymptomClassifier(num_classes=11, model_name=str(base_dir / "models" / "v2_dapt_base"))
    # NOTE(security): torch.load unpickles the checkpoint; only load trusted files
    # (consider weights_only=True where the installed torch version supports it).
    dl_model.load_state_dict(
        torch.load(base_dir / "models" / "baseline_v1" / "symptom_classifier.pt", map_location=device)
    )
    dl_model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(str(base_dir / "models" / "v2_dapt_base"))

    client = AsyncOpenAI(api_key=api_key, base_url="https://generativelanguage.googleapis.com/v1beta/openai/")

    all_new = []

    for symptom, config in TARGETS.items():
        print(f"\n{'=' * 50}")
        print(f"{symptom} (label_id={config['label_id']})")

        print(f" Generating {config['count']} candidates...")
        candidates = await generate_sentences(client, symptom, config["definition"], config["count"])
        print(f" Generated: {len(candidates)} unique")

        print(" Validating with DL model (top-3, min_prob=0.1)...")
        passed = validate_with_model(candidates, config["label_id"], dl_model, tokenizer, device)
        print(f" Passed: {len(passed)}/{len(candidates)} ({len(passed) / max(len(candidates), 1) * 100:.0f}%)")

        # Show the five most confident accepted sentences as a spot check.
        for s, p in passed[:5]:
            print(f' prob={p:.3f} "{s[:70]}"')

        for sent, prob in passed:
            # Running index across all symptoms keeps ids unique within this run.
            idx = len(all_new)
            all_new.append(
                {
                    "post_id": f"dlval_{symptom.lower()}_{idx}",
                    "sentence_id": f"dlval_s_{symptom.lower()}_{idx}",
                    "sentence_text": sent,
                    "clean_text": sent,
                    "label": symptom,
                    "label_id": config["label_id"],
                    "source": "dl_validated",
                    "similarity_score": prob,
                }
            )

    # Explicit columns keep the frame well-formed even when nothing passed validation.
    new_df = pd.DataFrame(
        all_new,
        columns=[
            "post_id",
            "sentence_id",
            "sentence_text",
            "clean_text",
            "label",
            "label_id",
            "source",
            "similarity_score",
        ],
    )

    aug_dir = base_dir / "data" / "redsm5" / "augmented_v4"
    combined = _merge_and_save(new_df, aug_dir)

    print(f"\n{'=' * 50}")
    print("COMPLETE")
    print(f"New DL-validated: {len(new_df)}")
    for sym in TARGETS:
        count = int((combined["label"] == sym).sum())
        print(f" {sym}: {count} total augmented")
    print(f"Total augmented: {len(combined)}")
    print(f"Saved to: {aug_dir / 'augmented_samples_final.csv'}")

    # Drop the model reference and release cached MPS memory.
    del dl_model
    if device.type == "mps":
        torch.mps.empty_cache()
|
|
|
|
# Script entry point: run the async pipeline only when executed directly, not on import.
if __name__ == "__main__":
    asyncio.run(main())
|
|