File size: 7,342 Bytes
3549219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ad523a
3549219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ad523a
3549219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95974bc
 
 
3549219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95974bc
3549219
 
 
 
 
 
 
 
2ad523a
3549219
95974bc
3549219
 
95974bc
3549219
 
95974bc
 
 
 
 
 
 
 
 
 
 
 
3549219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95974bc
2ad523a
3549219
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
"""
Generate + validate augmented data using the DL model itself as filter.

Strategy: Generate sentences via an LLM, then keep only those for which our DistilBERT
model ranks the target class among its top-3 predictions with probability > 0.1.

This produces training data the model can learn from — it reinforces
patterns the model already partially recognizes.

Usage:
    python augment_dl_validated.py
"""

import asyncio
import json
import os
import re
import sys
from pathlib import Path

import pandas as pd
import torch
import torch.nn.functional as F
from openai import AsyncOpenAI
from transformers import AutoTokenizer

sys.path.insert(0, str(Path(__file__).parent))
from train_redsm5_model import SymptomClassifier

# Turn off HuggingFace `tokenizers` internal parallelism for this process.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Symptoms to augment, keyed by label name. For each entry:
#   label_id   - class index the classifier must rank highly for a sentence to be kept
#   count      - number of candidate sentences requested from the LLM
#   definition - text inserted into the generation prompt to steer the LLM
TARGETS = dict(
    PSYCHOMOTOR=dict(
        label_id=4,
        count=100,
        definition=(
            "Observable slowing of physical movement/speech (retardation) "
            "OR restlessness/agitation (pacing, fidgeting, can't sit still). "
            "Must be PHYSICAL/OBSERVABLE, not just feeling tired."
        ),
    ),
    COGNITIVE_ISSUES=dict(
        label_id=7,
        count=100,
        definition=(
            "Difficulty concentrating, brain fog, indecisiveness, memory problems, "
            "losing track of conversations, feeling mentally slow/dull."
        ),
    ),
)


async def generate_sentences(
    client,
    symptom,
    definition,
    count,
    model="gemini-3-flash-preview",
    batch_size=25,
    batch_delay=0.5,
    attempts=3,
):
    """Ask the LLM for roughly ``count`` candidate sentences describing ``symptom``.

    Requests are issued in batches of at most ``batch_size`` sentences, and each
    batch's prompt asks for exactly that batch's share, so the total requested
    across all batches is ``count``.

    Args:
        client: ``AsyncOpenAI``-compatible client (uses ``client.chat.completions.create``).
        symptom: Symptom key such as ``"COGNITIVE_ISSUES"``; underscores become
            spaces and it is lower-cased for the prompt.
        definition: Free-text definition inserted into the prompt.
        count: Total number of sentences to request across all batches.
        model: Chat model name passed to the API.
        batch_size: Maximum number of sentences requested per API call.
        batch_delay: Seconds to pause between consecutive batches (not after the last).
        attempts: Tries per batch before that batch is given up.

    Returns:
        De-duplicated list of sentences (strings longer than 20 characters), in the
        order first received. The LLM may return fewer or more than requested.
    """
    readable = symptom.replace("_", " ").lower()
    collected = []
    for start in range(0, count, batch_size):
        wanted = min(batch_size, count - start)
        prompt = (
            f"Generate {wanted} unique Reddit-style first-person sentences where someone describes "
            f"{readable} symptoms.\n\n"
            f"Definition: {definition}\n\n"
            f"Use informal language. Each must be a single sentence. Vary intensity and vocabulary.\n"
            f"Return ONLY a JSON array of strings."
        )
        for attempt in range(1, attempts + 1):
            try:
                r = await client.chat.completions.create(
                    model=model,
                    messages=[
                        {"role": "system", "content": "Return ONLY valid JSON."},
                        {"role": "user", "content": prompt},
                    ],
                    max_tokens=4096,
                    temperature=0.9,
                )
                content = r.choices[0].message.content or ""
                # Take the outermost [...] span; this also skips any ```json fences.
                match = re.search(r"\[.*\]", content, re.DOTALL)
                sents = json.loads(match.group()) if match else None
            except Exception as exc:  # API / network / malformed JSON: best-effort retry
                print(
                    f"  [{symptom}] batch at {start}, attempt {attempt}/{attempts} failed: {exc!r}",
                    file=sys.stderr,
                )
                await asyncio.sleep(1)
                continue
            if not isinstance(sents, list):
                # Reply contained no JSON array; retry straight away.
                continue
            collected.extend(s for s in sents if isinstance(s, str) and len(s) > 20)
            break
        else:
            print(
                f"  [{symptom}] batch at {start}: no usable response after {attempts} attempts",
                file=sys.stderr,
            )
        # Rate-limit between batches only; no point sleeping after the last one.
        if start + batch_size < count:
            await asyncio.sleep(batch_delay)
    # dict.fromkeys de-duplicates while keeping first-seen order (set() would not).
    return list(dict.fromkeys(collected))


def validate_with_model(sentences, target_label_id, model, tokenizer, device, top_k=3, min_prob=0.1):
    """Filter candidates down to those the classifier already associates with the target.

    Each sentence is tokenized on its own (truncated to 128 tokens) and scored by
    ``model``. A sentence survives only if ``target_label_id`` is among the
    ``top_k`` highest-probability classes AND its softmax probability is strictly
    greater than ``min_prob``.

    Args:
        sentences: Iterable of candidate strings.
        target_label_id: Class index the sentence should be recognised as.
        model: Classifier called as ``model(input_ids, attention_mask)`` returning logits
            (assumed shape (1, num_classes) for a single sentence).
        tokenizer: HuggingFace-style tokenizer callable returning ``input_ids`` /
            ``attention_mask`` tensors.
        device: torch device the inputs are moved to.
        top_k: How many of the highest-ranked classes count as a hit.
        min_prob: Exclusive lower bound on the target-class probability.

    Returns:
        List of ``(sentence, target_prob)`` tuples, highest probability first
        (ties keep input order).

    Side effect: puts ``model`` into eval mode.
    """
    kept = []
    model.eval()
    with torch.no_grad():
        for text in sentences:
            encoded = tokenizer(text, truncation=True, max_length=128, return_tensors="pt")
            ids_on_device = encoded["input_ids"].to(device)
            mask_on_device = encoded["attention_mask"].to(device)
            class_probs = F.softmax(model(ids_on_device, mask_on_device), dim=1)[0]
            ranked = torch.topk(class_probs, top_k).indices.tolist()
            score = class_probs[target_label_id].item()
            accepted = score > min_prob and target_label_id in ranked
            if not accepted:
                continue
            kept.append((text, score))
    kept.sort(key=lambda pair: pair[1], reverse=True)
    return kept


def _pick_device():
    """Return the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _check_label_ids(label_map):
    """Fail fast if the model metadata maps a TARGETS name to a different class id.

    Only checked when ``label_map`` is a ``{name: int_id}`` dict; any other layout
    (NOTE(review): schema not visible here — confirm) is skipped silently.
    """
    if not isinstance(label_map, dict):
        return
    for symptom, config in TARGETS.items():
        mapped = label_map.get(symptom)
        if isinstance(mapped, int) and mapped != config["label_id"]:
            raise ValueError(
                f"TARGETS[{symptom!r}]['label_id']={config['label_id']} but model "
                f"metadata label_map says {mapped}"
            )


def _read_api_key(env_path):
    """Return LLM_API_KEY from the dotenv-style ``env_path``, else from the environment.

    Raises:
        RuntimeError: if the key is found in neither place (an empty key would
            otherwise only surface as swallowed per-request API failures).
    """
    key = ""
    if env_path.is_file():
        for line in env_path.read_text().splitlines():
            if line.startswith("LLM_API_KEY="):
                # Tolerate KEY="value" / KEY='value' quoting.
                key = line.split("=", 1)[1].strip().strip("\"'")
                break
    if not key:
        key = os.environ.get("LLM_API_KEY", "").strip()
    if not key:
        raise RuntimeError(f"LLM_API_KEY not found in {env_path} or in the environment")
    return key


def _merge_and_save(aug_dir, new_df):
    """Merge ``new_df`` into the existing augmented CSV and write augmented_samples_final.csv.

    Prefers ``augmented_samples_final.csv`` as the base, else ``augmented_samples.csv``;
    rows are de-duplicated on ``clean_text`` keeping the existing row. With no base
    file, ``new_df`` is written as-is.

    Returns:
        The combined DataFrame that was written.
    """
    aug_dir.mkdir(parents=True, exist_ok=True)
    for fname in ["augmented_samples_final.csv", "augmented_samples.csv"]:
        aug_path = aug_dir / fname
        if aug_path.exists():
            existing = pd.read_csv(aug_path)
            # Skip concatenating an empty frame (pandas warns about empty entries).
            frames = [existing] if new_df.empty else [existing, new_df]
            combined = pd.concat(frames, ignore_index=True)
            combined = combined.drop_duplicates(subset=["clean_text"], keep="first")
            break
    else:
        combined = new_df
    combined.to_csv(aug_dir / "augmented_samples_final.csv", index=False)
    return combined


async def main():
    """Generate LLM candidates per TARGETS entry, keep those the classifier accepts, save them.

    Pipeline: load the trained SymptomClassifier, generate candidate sentences for
    each symptom in TARGETS, filter them with ``validate_with_model``, then merge
    the survivors into ``data/redsm5/augmented_v4/augmented_samples_final.csv``.
    """
    base_dir = Path(__file__).parent.parent
    device = _pick_device()

    # Load model
    print("Loading DL model for validation...")
    with open(base_dir / "models" / "baseline_v1" / "redsm5_metadata.json", encoding="utf-8") as f:
        meta = json.load(f)
    _check_label_ids(meta["label_map"])

    model_dir = base_dir / "models" / "v2_dapt_base"
    dl_model = SymptomClassifier(num_classes=11, model_name=str(model_dir))
    dl_model.load_state_dict(
        torch.load(base_dir / "models" / "baseline_v1" / "symptom_classifier.pt", map_location=device)
    )
    dl_model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(str(model_dir))

    # LLM client (Gemini via its OpenAI-compatible endpoint)
    api_key = _read_api_key(base_dir.parent / ".env")
    client = AsyncOpenAI(api_key=api_key, base_url="https://generativelanguage.googleapis.com/v1beta/openai/")

    columns = [
        "post_id",
        "sentence_id",
        "sentence_text",
        "clean_text",
        "label",
        "label_id",
        "source",
        "similarity_score",
    ]
    all_new = []
    try:
        for symptom, config in TARGETS.items():
            print(f"\n{'=' * 50}")
            print(f"{symptom} (label_id={config['label_id']})")

            # Generate
            print(f"  Generating {config['count']} candidates...")
            candidates = await generate_sentences(client, symptom, config["definition"], config["count"])
            print(f"  Generated: {len(candidates)} unique")

            # DL-model validate
            print("  Validating with DL model (top-3, min_prob=0.1)...")
            passed = validate_with_model(candidates, config["label_id"], dl_model, tokenizer, device)
            print(f"  Passed: {len(passed)}/{len(candidates)} ({len(passed) / max(len(candidates), 1) * 100:.0f}%)")

            for s, p in passed[:5]:
                print(f'    prob={p:.3f} "{s[:70]}"')

            for sent, prob in passed:
                all_new.append(
                    {
                        "post_id": f"dlval_{symptom.lower()}_{len(all_new)}",
                        "sentence_id": f"dlval_s_{symptom.lower()}_{len(all_new)}",
                        "sentence_text": sent,
                        "clean_text": sent,
                        "label": symptom,
                        "label_id": config["label_id"],
                        "source": "dl_validated",
                        # Classifier probability for the target class, stored in the
                        # column name used by the existing augmented CSV.
                        "similarity_score": prob,
                    }
                )
    finally:
        await client.close()

    # Explicit columns keep "label" present even when nothing passed validation.
    new_df = pd.DataFrame(all_new, columns=columns)

    # Merge with existing augmented
    aug_dir = base_dir / "data" / "redsm5" / "augmented_v4"
    combined = _merge_and_save(aug_dir, new_df)

    print(f"\n{'=' * 50}")
    print("COMPLETE")
    print(f"New DL-validated: {len(new_df)}")
    for sym in TARGETS:
        count = len(combined[combined["label"] == sym])
        print(f"  {sym}: {count} total augmented")
    print(f"Total augmented: {len(combined)}")
    print(f"Saved to: {aug_dir / 'augmented_samples_final.csv'}")

    del dl_model
    if device.type == "mps":
        torch.mps.empty_cache()
    elif device.type == "cuda":
        torch.cuda.empty_cache()


if __name__ == "__main__":
    # Script entry point: run the async pipeline to completion on a fresh event loop.
    asyncio.run(main())