# depscreen/ml/scripts/augment_dl_validated.py
"""
Generate + validate augmented data using the DL model itself as filter.
Strategy: Generate sentences via LLM, then keep only those our DistilBERT
model classifies as the target class (top-3 predictions, prob > 0.1).
This produces training data the model can learn from — it reinforces
patterns the model already partially recognizes.
Usage:
python augment_dl_validated.py
"""
import asyncio
import json
import os
import re
import sys
from pathlib import Path
import pandas as pd
import torch
import torch.nn.functional as F
from openai import AsyncOpenAI
from transformers import AutoTokenizer
sys.path.insert(0, str(Path(__file__).parent))
from train_redsm5_model import SymptomClassifier
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Minority classes to augment: maps label name -> the classifier's label id,
# how many LLM candidate sentences to request, and the symptom definition
# that is interpolated into the generation prompt.
TARGETS = {
    "PSYCHOMOTOR": {
        "label_id": 4,  # index of this class in the classifier's output logits
        "count": 100,   # number of candidate sentences to request from the LLM
        "definition": (
            "Observable slowing of physical movement/speech (retardation) "
            "OR restlessness/agitation (pacing, fidgeting, can't sit still). "
            "Must be PHYSICAL/OBSERVABLE, not just feeling tired."
        ),
    },
    "COGNITIVE_ISSUES": {
        "label_id": 7,
        "count": 100,
        "definition": (
            "Difficulty concentrating, brain fog, indecisiveness, memory problems, "
            "losing track of conversations, feeling mentally slow/dull."
        ),
    },
}
async def generate_sentences(client, symptom, definition, count, model="gemini-3-flash-preview"):
    """Ask the LLM for candidate first-person symptom sentences.

    Issues one request per batch of 25 (each request asks for ``count``
    sentences — deliberate over-generation, since the downstream DL filter
    discards most candidates), retries each batch up to 3 times on any
    error, and returns the deduplicated sentences.

    Args:
        client: AsyncOpenAI-compatible client exposing ``chat.completions.create``.
        symptom: Target label name, e.g. ``"PSYCHOMOTOR"``.
        definition: Human-readable symptom definition inserted into the prompt.
        count: Number of sentences requested per call; also sets batch count.
        model: Chat model identifier.

    Returns:
        list[str]: unique sentences longer than 20 chars, in first-seen order.
    """
    all_sentences = []
    prompt = (
        f"Generate {count} unique Reddit-style first-person sentences where someone describes "
        f"{symptom.replace('_', ' ').lower()} symptoms.\n\n"
        f"Definition: {definition}\n\n"
        f"Use informal language. Each must be a single sentence. Vary intensity and vocabulary.\n"
        f"Return ONLY a JSON array of strings."
    )
    for _ in range(0, count, 25):
        for _attempt in range(3):
            try:
                r = await client.chat.completions.create(
                    model=model,
                    messages=[
                        {"role": "system", "content": "Return ONLY valid JSON."},
                        {"role": "user", "content": prompt},
                    ],
                    max_tokens=4096,
                    temperature=0.9,
                )
                content = r.choices[0].message.content or ""
                # Strip markdown code fences the model sometimes wraps around JSON.
                content = re.sub(r"```json\s*", "", content)
                content = re.sub(r"```\s*$", "", content)
                match = re.search(r"\[.*\]", content, re.DOTALL)
                if match:
                    sents = json.loads(match.group())
                    # Keep only non-trivial string entries (> 20 chars).
                    all_sentences.extend([s for s in sents if isinstance(s, str) and len(s) > 20])
                    break
            except Exception:
                # Best-effort retry: transient API/JSON errors just back off.
                await asyncio.sleep(1)
        await asyncio.sleep(0.5)
    # dict.fromkeys dedups while preserving first-seen order; list(set(...))
    # would make the output order (and downstream sample ids) nondeterministic.
    return list(dict.fromkeys(all_sentences))
def validate_with_model(sentences, target_label_id, model, tokenizer, device, top_k=3, min_prob=0.1):
    """Filter candidate sentences with the DL classifier itself.

    A sentence is kept when the model ranks ``target_label_id`` among its
    top-``top_k`` classes AND assigns it probability above ``min_prob``.

    Returns:
        list[tuple[str, float]]: (sentence, probability) pairs, sorted by
        probability in descending order.
    """
    model.eval()
    kept = []
    with torch.no_grad():
        for text in sentences:
            encoded = tokenizer(text, truncation=True, max_length=128, return_tensors="pt")
            logits = model(encoded["input_ids"].to(device), encoded["attention_mask"].to(device))
            probabilities = F.softmax(logits, dim=1)[0]
            ranked = torch.topk(probabilities, top_k).indices.tolist()
            score = probabilities[target_label_id].item()
            if target_label_id in ranked and score > min_prob:
                kept.append((text, score))
    kept.sort(key=lambda pair: pair[1], reverse=True)
    return kept
def _read_api_key(env_path):
    """Return LLM_API_KEY from a dotenv-style file, falling back to the
    process environment; empty string if neither is set.

    Previously a missing .env file raised FileNotFoundError before any work
    was done.
    """
    if env_path.exists():
        for line in env_path.read_text().splitlines():
            if line.startswith("LLM_API_KEY="):
                return line.split("=", 1)[1].strip()
    return os.environ.get("LLM_API_KEY", "")


def _load_validator(base_dir, device):
    """Load the trained symptom classifier and its tokenizer onto `device`."""
    # Read the metadata up front so we fail fast if the model directory is
    # incomplete; the label map itself is not needed by this script.
    with open(base_dir / "models" / "baseline_v1" / "redsm5_metadata.json") as f:
        json.load(f)
    model = SymptomClassifier(num_classes=11, model_name=str(base_dir / "models" / "v2_dapt_base"))
    # NOTE(review): consider torch.load(..., weights_only=True) if the
    # checkpoint could ever come from an untrusted source.
    model.load_state_dict(
        torch.load(base_dir / "models" / "baseline_v1" / "symptom_classifier.pt", map_location=device)
    )
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(str(base_dir / "models" / "v2_dapt_base"))
    return model, tokenizer


async def main():
    """Generate LLM candidates for under-represented labels, keep only those
    the DL model already recognizes, and merge them into the augmented CSV."""
    base_dir = Path(__file__).parent.parent
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

    # Load model
    print("Loading DL model for validation...")
    dl_model, tokenizer = _load_validator(base_dir, device)

    # LLM client (Gemini via its OpenAI-compatible endpoint)
    client = AsyncOpenAI(
        api_key=_read_api_key(base_dir.parent / ".env"),
        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
    )

    all_new = []
    for symptom, config in TARGETS.items():
        print(f"\n{'=' * 50}")
        print(f"{symptom} (label_id={config['label_id']})")
        # Generate
        print(f" Generating {config['count']} candidates...")
        candidates = await generate_sentences(client, symptom, config["definition"], config["count"])
        print(f" Generated: {len(candidates)} unique")
        # DL-model validate
        print(" Validating with DL model (top-3, min_prob=0.1)...")
        passed = validate_with_model(candidates, config["label_id"], dl_model, tokenizer, device)
        print(f" Passed: {len(passed)}/{len(candidates)} ({len(passed) / max(len(candidates), 1) * 100:.0f}%)")
        for s, p in passed[:5]:
            print(f' prob={p:.3f} "{s[:70]}"')
        for sent, prob in passed:
            all_new.append(
                {
                    "post_id": f"dlval_{symptom.lower()}_{len(all_new)}",
                    "sentence_id": f"dlval_s_{symptom.lower()}_{len(all_new)}",
                    "sentence_text": sent,
                    "clean_text": sent,
                    "label": symptom,
                    "label_id": config["label_id"],
                    "source": "dl_validated",
                    # model probability, stored under the schema's existing name
                    "similarity_score": prob,
                }
            )

    # Explicit columns: an empty candidate list would otherwise yield a
    # DataFrame with no columns and break drop_duplicates / label filtering.
    new_df = pd.DataFrame(
        all_new,
        columns=[
            "post_id", "sentence_id", "sentence_text", "clean_text",
            "label", "label_id", "source", "similarity_score",
        ],
    )

    # Merge with existing augmented data (prefer the "final" file if present).
    aug_dir = base_dir / "data" / "redsm5" / "augmented_v4"
    aug_dir.mkdir(parents=True, exist_ok=True)
    out_path = aug_dir / "augmented_samples_final.csv"
    for fname in ["augmented_samples_final.csv", "augmented_samples.csv"]:
        aug_path = aug_dir / fname
        if aug_path.exists():
            existing = pd.read_csv(aug_path)
            combined = pd.concat([existing, new_df], ignore_index=True)
            # keep="first" keeps the existing row when a generated sentence
            # duplicates one already on disk.
            combined = combined.drop_duplicates(subset=["clean_text"], keep="first")
            break
    else:
        combined = new_df
    combined.to_csv(out_path, index=False)

    print(f"\n{'=' * 50}")
    print("COMPLETE")
    print(f"New DL-validated: {len(new_df)}")
    for sym in TARGETS:
        count = len(combined[combined["label"] == sym])
        print(f" {sym}: {count} total augmented")
    print(f"Total augmented: {len(combined)}")
    print(f"Saved to: {aug_dir / 'augmented_samples_final.csv'}")

    # Release the model before exit; MPS memory is not reclaimed automatically.
    del dl_model
    if device.type == "mps":
        torch.mps.empty_cache()


if __name__ == "__main__":
    asyncio.run(main())