File size: 7,698 Bytes
62ff3c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6c18ca
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import json
import re


def load_schema(schema_path):
    """Read the JSON document at ``schema_path`` and return the parsed schema.

    The file is opened as UTF-8 text; whatever object the JSON decodes to is
    returned unchanged.
    """
    with open(schema_path, mode='r', encoding='utf-8') as schema_file:
        raw_text = schema_file.read()
    return json.loads(raw_text)


def create_empty_profile():
    """
    Build a blank user profile: the starting point before anything is known.

    Scalar fields start as None and list-valued fields start as new empty
    lists, so every call returns a fully independent structure.
    """
    demographics = {
        "population": None,
        "identity_factors": [],
        "language": None,
        "pronouns": None,
    }
    logistics = {
        "zipcode": None,
        "region": None,
        "profession": None,
        "accessibility_needs": [],
        "insurance": None,
        "treatment_history": None,
    }
    status = {
        "current_state": None,
        "crisis_level": None,
        "temporary_factors": [],
    }
    clinical = {
        "primary_focus": None,
        "substances": [],
    }
    preferences = {
        "setting": None,
        "therapy_approach": None,
        "scheduling": [],
        "barriers": [],
        "contact_channel": None,
    }

    # Section order matters to consumers that iterate the profile in order.
    return {
        "demographics": demographics,
        "logistics": logistics,
        "status": status,
        "clinical": clinical,
        "preferences": preferences,
    }


def extract_profile_updates(schema, user_input):
    """
    Detect profile information in a user message, driven by the schema.

    Each schema field is handled according to its "type":

    - "extracted": delegated to ``_extract_field``; kept when the helper
      returns something other than None.
    - "single": the value of the first option (in schema order) having a
      keyword that occurs in the message.
    - "multi": the values of every option having a keyword that occurs in
      the message, in schema order.

    Keyword matching is a case-insensitive substring test; empty keywords
    are ignored. Fields of any other type are skipped.

    Args:
        schema: The loaded profile schema dict (category -> field -> spec).
        user_input: The user's message text.

    Returns:
        dict: Nested ``{category: {field: value}}`` containing only the
              fields (and categories) for which something was detected.
    """
    input_lower = user_input.lower()
    updates = {}

    for section, fields in schema.items():
        found = {}

        for name, spec in fields.items():
            kind = spec.get("type")

            if kind == "extracted":
                # Pattern-based / free-text fields are resolved by the helper.
                extracted = _extract_field(name, spec, user_input, input_lower)
                if extracted is not None:
                    found[name] = extracted
                continue

            if kind not in ("single", "multi"):
                continue

            # An option counts once, as soon as any of its keywords is seen.
            hits = [
                option["value"]
                for option in spec.get("options", [])
                if any(
                    kw and kw.lower() in input_lower
                    for kw in option.get("keywords", [])
                )
            ]
            if hits:
                found[name] = hits[0] if kind == "single" else hits

        if found:
            updates[section] = found

    return updates


def _extract_field(field_name, field_def, user_input, input_lower):
    """Handle extraction for non-option fields like zipcode and treatment_history."""
    if field_name == "zipcode":
        pattern = field_def.get("pattern", r"\b\d{5}\b")
        match = re.search(pattern, user_input)
        if match:
            return match.group()
        return None

    if field_name == "region":
        # Region is typically set explicitly or by the LLM, not keyword-matched.
        # We do a lightweight check for common geographic indicators.
        geo_patterns = [
            r"\bin\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)",  # "in Boston", "in Pocahontas County"
            r"\bnear\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)",  # "near Springfield"
            r"\bfrom\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)",  # "from Cambridge"
        ]
        for pattern in geo_patterns:
            match = re.search(pattern, user_input)
            if match:
                return match.group(1)
        return None

    if field_name == "treatment_history":
        history_keywords = ["rehab", "treatment before", "been to", "tried",
                            "previous treatment", "went to", "was in",
                            "12-step", "residential before", "relapsed"]
        for keyword in history_keywords:
            if keyword in input_lower:
                return user_input  # store the raw message as context
        return None

    return None


def merge_profile(profile, updates):
    """
    Fold detected updates into an existing profile, in place.

    Rules applied per field:

    - Categories or fields that the profile does not already contain are
      skipped.
    - A None update never clears existing data.
    - When the stored value is a list, incoming values (a list, or a single
      value) are appended unless already present.
    - Otherwise the stored value is replaced by the incoming one.

    Args:
        profile: The current user profile dict (mutated by this call).
        updates: The updates dict, e.g. from extract_profile_updates().

    Returns:
        dict: The same ``profile`` object that was passed in.
    """
    for section, changes in updates.items():
        if section not in profile:
            continue
        current = profile[section]

        for key, incoming in changes.items():
            if key not in current or incoming is None:
                continue

            slot = current[key]
            if not isinstance(slot, list):
                # Scalar field: the newest value wins.
                current[key] = incoming
                continue

            # List field: treat a lone value as a one-item batch.
            additions = incoming if isinstance(incoming, list) else [incoming]
            for item in additions:
                if item not in slot:
                    slot.append(item)

    return profile


def _summary_text(value):
    """Render ``value`` as a single line of text with whitespace collapsed."""
    # str.split() with no arguments splits on any run of whitespace
    # (including newlines and tabs) and drops leading/trailing whitespace.
    return " ".join(str(value).split())


def profile_to_summary(profile):
    """
    Convert a user profile dict into a concise text summary for injection
    into the system prompt. Only includes fields that have been filled in.

    Every value is flattened to a single line before being written. Some
    profile fields hold free text (treatment_history stores the raw user
    message), and an embedded newline would otherwise break the
    one-line-per-field layout and could make user-supplied text look like
    additional section headers or fields in the prompt.

    Fields are skipped when their value is None, an empty list, or blank
    after whitespace normalisation. A category that is missing from the
    profile, or present but None/empty, contributes nothing.

    Args:
        profile: The user profile dict (category -> field -> value).

    Returns:
        str: A human-readable summary, or empty string if profile is empty.
    """
    category_labels = {
        "demographics": "Demographics",
        "logistics": "Logistics & History",
        "status": "Current Status",
        "clinical": "Clinical Needs",
        "preferences": "Preferences & Barriers"
    }

    lines = []

    for category_name, category_label in category_labels.items():
        # Tolerate a category explicitly set to None as well as a missing one.
        category = profile.get(category_name) or {}
        category_lines = []

        for field_name, value in category.items():
            if value is None:
                continue

            if isinstance(value, list):
                parts = [_summary_text(v) for v in value]
                rendered = ", ".join(part for part in parts if part)
            else:
                rendered = _summary_text(value)

            # Covers empty lists, empty strings and whitespace-only values.
            if not rendered:
                continue

            # Format the field name nicely, e.g. "crisis_level" -> "Crisis Level".
            display_name = field_name.replace("_", " ").title()
            category_lines.append(f"  - {display_name}: {rendered}")

        if category_lines:
            lines.append(f"[{category_label}]")
            lines.extend(category_lines)

    if not lines:
        return ""

    header = (
        "USER PROFILE (already collected — DO NOT ask the user again for any of these details):\n"
    )
    return header + "\n".join(lines)