PR-AGENT / server /form_schema.py
Seth
Update
de686dc
"""AI-generated form field definitions per commodity — cached in SQLite for consistency."""
from __future__ import annotations

import json
import logging
import os
import sqlite3
from typing import Any

from openai import OpenAI, OpenAIError

from server.catalog import get_commodity, summarize_row
from server.pr_lines import INTERVAL_ORDER
# Generic choice list that _coerce_field_selectable substitutes whenever a
# select/chips field ends up with fewer than two usable options. That includes a
# text/textarea field that was converted to select. Such fields can come from
# model output or from legacy cached schemas.
# The last option points users at the specification notes, which the
# application collects separately (see the rules in SCHEMA_GEN_SYSTEM).
GENERIC_SELECT_FALLBACK = (
    "Standard / typical requirement",
    "Enhanced vs baseline",
    "Economy / essential only",
    "Pilot or limited scope",
    "Strategic priority program",
    "Other — use specification notes below",
)
# System prompt sent by generate_schema_with_llm. The model's reply is parsed as
# JSON and run through _validate_and_normalize. That step coerces text/textarea
# fields to select and drops invalid or duplicate fields no matter what the
# model emits. Other rules below (field count, excluded topics, option counts)
# are not checked by this module.
# Editing this text only affects commodities whose schema is not cached yet,
# because cached schemas are returned without regeneration.
SCHEMA_GEN_SYSTEM = """You design procurement intake forms for a single catalogue commodity (segment → family → class → commodity).
Return ONE JSON object only (no markdown). Shape:
{
"fields": [
{
"id": "stable_snake_case_id",
"label": "Full question text shown to the user (no Q1/Q2 prefixes)",
"type": "select" | "chips" | "number",
"options": ["required for select and chips: 3–12 distinct, short option strings"],
"unit": "ONLY for type number: short suffix shown next to the input (e.g. kg, lb, mm, in, %)"
}
]
}
Rules:
- 3 to 7 fields. Labels must be clear procurement questions for THIS commodity type.
- Do NOT include: number of deliveries, delivery interval/frequency, year for scheduling, or a generic "other / free text specifications" field — the application collects those separately.
- **No open-ended typing:** NEVER use type "text" or "textarea". Users must tap choices only.
- Prefer **select** with 5–12 concise options for objectives, scope, methodology, audience, timing, risk, quality level, etc.
- Use **chips** for 3–8 mutually exclusive options when labels are short (single choice — same as select, shown as buttons).
- Use **number** only for true numeric values (counts, currency amounts, percentages, sizes, weights, dimensions).
- For **every number field**, set **"unit"** to the metric users should enter (e.g. `"kg"` for weight capacity, `"mm"` for seat depth, `"lb"` only if Imperial is explicit). Never leave unit ambiguous when the question is a measurement.
- Every option string must be self-contained (no reliance on free-form explanations). If a case might need nuance, add an option such as "Other — see specification notes below".
- Use stable `id` values (snake_case) — they are keys in saved data.
- Same commodity must always get the same structure when regenerated; the app caches by commodity code, but ids and intent must stay consistent if you see similar commodities.
"""
def ensure_form_schema_table(conn: sqlite3.Connection) -> None:
    """Create the ``commodity_form_schemas`` cache table if missing, then commit.

    Safe to call repeatedly because of ``IF NOT EXISTS``. The trailing
    ``commit()`` also commits any transaction the caller already had open on
    *conn*.
    """
    ddl = (
        "\n"
        "CREATE TABLE IF NOT EXISTS commodity_form_schemas (\n"
        "commodity_code INTEGER PRIMARY KEY,\n"
        "schema_json TEXT NOT NULL,\n"
        "updated_at TEXT DEFAULT CURRENT_TIMESTAMP\n"
        ")\n"
    )
    conn.execute(ddl)
    conn.commit()
def _load_cached(conn: sqlite3.Connection, commodity_code: int) -> dict[str, Any] | None:
cur = conn.cursor()
cur.execute(
"SELECT schema_json FROM commodity_form_schemas WHERE commodity_code = ?",
(commodity_code,),
)
row = cur.fetchone()
if not row:
return None
try:
return json.loads(row[0])
except json.JSONDecodeError:
return None
def _save_cache(conn: sqlite3.Connection, commodity_code: int, schema: dict[str, Any]) -> None:
conn.execute(
"""
INSERT INTO commodity_form_schemas (commodity_code, schema_json, updated_at)
VALUES (?, ?, datetime('now'))
ON CONFLICT(commodity_code) DO UPDATE SET
schema_json = excluded.schema_json,
updated_at = datetime('now')
""",
(commodity_code, json.dumps(schema, ensure_ascii=False)),
)
conn.commit()
# Commodity-agnostic field templates used by _fallback_schema. That schema is
# served when OPENAI_API_KEY is unset, when the model reply is not valid JSON,
# or when no valid field survives validation.
# Each template is copied and passed through _coerce_field_selectable, so these
# dicts are never mutated. Their ids become keys in saved data.
_FALLBACK_FIELDS_RAW: list[dict[str, Any]] = [
    {
        "id": "primary_scope",
        "label": "What is the primary scope or geography for this requirement?",
        "type": "select",
        "options": [
            "Local / single site",
            "Regional",
            "National",
            "International",
            "Multi-region program",
            "To be determined",
        ],
    },
    {
        "id": "scale_band",
        "label": "What scale band best matches expected volume or spend?",
        "type": "chips",
        "options": [
            "Pilot / small",
            "Medium",
            "Large",
            "Enterprise-wide",
            "Not yet estimated",
        ],
    },
    {
        "id": "compliance_focus",
        "label": "Which compliance themes apply (if any)?",
        # Single-choice select; mixed cases route to the specification notes option.
        "type": "select",
        "options": [
            "None identified yet",
            "Data privacy / residency",
            "Safety / quality standards",
            "Financial / audit controls",
            "Industry-specific regulations",
            "Mixed — see specification notes",
        ],
    },
]
def _fallback_schema() -> dict[str, Any]:
    """Build the generic schema served when an AI-generated one is unavailable.

    Each template in ``_FALLBACK_FIELDS_RAW`` is shallow-copied before it is
    coerced, so the module-level templates are left untouched. The result is
    tagged ``source == "fallback"``.
    """
    templates = (dict(spec) for spec in _FALLBACK_FIELDS_RAW)
    coerced = list(map(_coerce_field_selectable, templates))
    return {"fields": coerced, "source": "fallback"}
def _coerce_field_selectable(entry: dict[str, Any]) -> dict[str, Any]:
"""Ensure fields are selectable (select/chips) or number — never free-text."""
typ = str(entry.get("type") or "select").lower()
if typ in ("text", "textarea"):
typ = "select"
elif typ == "number":
out = {**entry, "type": "number"}
out.pop("options", None)
unit = str(out.get("unit") or "").strip()
if unit:
out["unit"] = unit[:24]
else:
out.pop("unit", None)
return out
elif typ not in ("select", "chips"):
typ = "select"
opts_raw = entry.get("options")
clean: list[str] = []
if isinstance(opts_raw, list):
clean = [str(o).strip() for o in opts_raw if str(o).strip()]
if len(clean) < 2:
clean = list(GENERIC_SELECT_FALLBACK)
return {**entry, "type": typ, "options": clean}
def _normalize_fields(fields_in: Any) -> list[dict[str, Any]]:
    """Clean a raw ``fields`` sequence into de-duplicated, selectable field dicts.

    Shared by ``_validate_and_normalize`` (fresh model output) and
    ``_coerce_cached_schema`` (cached rows), so both paths apply the same rules:

    * non-dict entries, and entries with a blank ``id`` or ``label``, are dropped;
    * only the first entry for each ``id`` is kept;
    * unknown types become ``select`` *before* options are read, so a choice
      list emitted under an unrecognised type is kept;
    * each entry is reduced to ``id``/``label``/``type`` plus ``options``
      (select/chips) or ``unit`` (number), then finished by
      ``_coerce_field_selectable``.

    Anything that is not a list or tuple yields an empty list.
    """
    if not isinstance(fields_in, (list, tuple)):
        return []
    fields_out: list[dict[str, Any]] = []
    seen_ids: set[str] = set()
    for f in fields_in:
        if not isinstance(f, dict):
            continue
        fid = str(f.get("id") or "").strip()
        label = str(f.get("label") or "").strip()
        if not fid or not label or fid in seen_ids:
            continue
        seen_ids.add(fid)
        typ = str(f.get("type") or "select").strip().lower()
        if typ not in ("select", "number", "text", "chips", "textarea"):
            typ = "select"
        entry: dict[str, Any] = {"id": fid, "label": label, "type": typ}
        opts = f.get("options")
        if typ in ("select", "chips") and isinstance(opts, list) and opts:
            # Drop None so it is not stringified into a literal "None" option.
            entry["options"] = [str(o) for o in opts if o is not None and str(o).strip()]
        elif typ == "number":
            unit = str(f.get("unit") or "").strip()
            if unit:
                entry["unit"] = unit[:24]
        fields_out.append(_coerce_field_selectable(entry))
    return fields_out


def _validate_and_normalize(raw: dict[str, Any]) -> dict[str, Any]:
    """Turn the model's parsed JSON reply into a clean schema.

    Returns ``{"fields": [...], "source": "openai"}``. Returns
    ``_fallback_schema()`` instead when *raw* is not a JSON object (the model
    ignored the instructions) or when no usable field survives normalisation.
    """
    if not isinstance(raw, dict):
        return _fallback_schema()
    fields_out = _normalize_fields(raw.get("fields"))
    if not fields_out:
        return _fallback_schema()
    return {"fields": fields_out, "source": "openai"}


def _coerce_cached_schema(cached: dict[str, Any]) -> dict[str, Any]:
    """Upgrade a cached schema (e.g. legacy text/textarea fields) to selectable controls.

    Extra top-level keys of *cached* are preserved, and ``source`` defaults to
    ``"cache"`` when absent. If no usable field survives, the generic fallback
    schema is returned instead. Either way the result carries
    ``interval_options``.
    """
    fields_out = _normalize_fields(cached.get("fields"))
    if fields_out:
        out = {**cached, "fields": fields_out, "source": cached.get("source", "cache")}
    else:
        out = _fallback_schema()
    # Attach on both paths: callers return this dict to clients as-is.
    out["interval_options"] = INTERVAL_ORDER
    return out
def generate_schema_with_llm(row: dict[str, Any]) -> dict[str, Any]:
    """Ask the chat model for an intake-form schema for *row*'s commodity.

    The commodity's codes, path, title and definition (via ``summarize_row``)
    are sent as the user message, with ``SCHEMA_GEN_SYSTEM`` as the system
    prompt. A JSON-object response is requested.

    Returns ``_validate_and_normalize(parsed_reply)`` on success. Returns
    ``_fallback_schema()`` (``source == "fallback"``) instead in these cases:

    * ``OPENAI_API_KEY`` is unset;
    * the OpenAI client raises (network, auth, rate limit, ...);
    * the reply has no choices;
    * the reply is not valid JSON, or is JSON but not an object.

    Failures after the key check are logged as warnings.

    Environment:
        OPENAI_API_KEY: needed for a real call.
        OPENAI_MODEL: chat model name (default ``gpt-4o-mini``).
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        return _fallback_schema()
    log = logging.getLogger(__name__)
    s = summarize_row(row)
    commodity_code = s.get("commodity_code")
    user_block = json.dumps(
        {
            "segment_code": s.get("segment_code"),
            "family_code": s.get("family_code"),
            "class_code": s.get("class_code"),
            "commodity_code": commodity_code,
            "path": s.get("path"),
            "commodity_title": s.get("commodity_title"),
            "commodity_definition": s.get("commodity_definition"),
        },
        ensure_ascii=False,
    )
    try:
        client = OpenAI(api_key=api_key)
        resp = client.chat.completions.create(
            model=os.environ.get("OPENAI_MODEL", "gpt-4o-mini"),
            messages=[
                {"role": "system", "content": SCHEMA_GEN_SYSTEM},
                {"role": "user", "content": user_block},
            ],
            temperature=0.2,
            response_format={"type": "json_object"},
        )
    except OpenAIError:
        # Degrade to the generic form like every other failure mode here,
        # instead of failing the whole request on a transient API problem.
        log.warning("Form schema generation failed for commodity %s", commodity_code, exc_info=True)
        return _fallback_schema()
    if not resp.choices:
        log.warning("Form schema generation returned no choices for commodity %s", commodity_code)
        return _fallback_schema()
    text = (resp.choices[0].message.content or "").strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        log.warning("Form schema reply was not valid JSON for commodity %s", commodity_code)
        return _fallback_schema()
    if not isinstance(parsed, dict):
        log.warning("Form schema reply was not a JSON object for commodity %s", commodity_code)
        return _fallback_schema()
    return _validate_and_normalize(parsed)
def get_or_create_schema(conn: sqlite3.Connection, commodity_code: int) -> dict[str, Any]:
    """Return the intake-form schema for *commodity_code*, generating it on first use.

    1. A cached schema with a non-empty ``fields`` list is upgraded with
       ``_coerce_cached_schema`` and returned.
    2. If the commodity is unknown, the result is
       ``{"fields": [], "error": "commodity_not_found", ...}``.
    3. Otherwise the schema comes from ``generate_schema_with_llm``. Only
       model-generated schemas are cached. The generic fallback
       (``source == "fallback"``, e.g. missing API key or a failed call) is
       returned but not persisted, so a later request can still get a
       commodity-specific schema.

    Every returned dict carries ``interval_options``.
    """
    ensure_form_schema_table(conn)
    cached = _load_cached(conn, commodity_code)
    if isinstance(cached, dict) and cached.get("fields"):
        schema = _coerce_cached_schema(cached)
        schema["interval_options"] = INTERVAL_ORDER
        return schema
    row = get_commodity(conn, commodity_code)
    if not row:
        return {"fields": [], "error": "commodity_not_found", "interval_options": INTERVAL_ORDER}
    schema = generate_schema_with_llm(row)
    # Persisting the fallback would pin this commodity to the generic form forever.
    if schema.get("fields") and schema.get("source") != "fallback":
        # Saved before interval_options is attached, so that key is never cached.
        _save_cache(conn, commodity_code, schema)
    schema["interval_options"] = INTERVAL_ORDER
    return schema