site-intelligence-studio / src /small_model_assistant.py
Eishaan's picture
Prepare submission-ready Space build
c590d67
Raw
History Blame Contribute Delete
6.22 kB
from __future__ import annotations
import os
from typing import Iterable
import requests
from .models import EvidenceItem, SiteSelection
from .safety import assert_safe_text
# Hugging Face model ID used by build_assistant_brief when the SMALL_MODEL_ID
# environment variable is not set.
DEFAULT_SMALL_MODEL_ID = "HuggingFaceTB/SmolLM2-360M-Instruct"
def build_assistant_brief(
    *,
    selection: SiteSelection,
    evidence_rows: Iterable[EvidenceItem],
    warnings: list[str],
    project_type: str,
) -> str:
    """Return a Markdown brief for the site-analysis board.

    The deterministic template brief is always built first. The hosted small
    model is only called when ``ENABLE_SMALL_MODEL`` is ``1`` and a token is
    present in ``HF_TOKEN`` or ``HUGGINGFACEHUB_API_TOKEN``. Every failure on
    the model path degrades to the template brief with a status note appended,
    so this function does not raise for model/network problems.

    Args:
        selection: Site selection; its ``selection_type`` is used in the text.
        evidence_rows: Evidence rows; materialised once into a list.
        warnings: Extra warning strings to surface in the brief/prompt.
        project_type: Free-text project type; may be empty.

    Returns:
        Markdown text: either the model output prefixed with a status line, or
        the template brief (optionally followed by a status note).
    """
    rows = list(evidence_rows)
    fallback = _template_brief(selection, rows, warnings, project_type)
    if os.getenv("ENABLE_SMALL_MODEL", "").strip() != "1":
        return fallback

    # Strip the token: secrets pasted with a trailing newline/space would
    # otherwise end up verbatim in the Authorization header, and a
    # whitespace-only value should count as "not configured".
    token = (os.getenv("HF_TOKEN") or "").strip() or (
        os.getenv("HUGGINGFACEHUB_API_TOKEN") or ""
    ).strip()
    if not token:
        return fallback + "\n\n**Model status:** small-model generation skipped because no HF token is configured."

    # os.getenv only applies its default when the variable is absent; a
    # variable that is set but blank must also fall back to the default model.
    model_id = os.getenv("SMALL_MODEL_ID", "").strip() or DEFAULT_SMALL_MODEL_ID
    prompt = _prompt(selection, rows, warnings, project_type)
    try:
        text = _call_hf_inference(model_id, token, prompt)
        if not text.strip():
            return fallback + f"\n\n**Model status:** `{model_id}` returned no usable text; template fallback shown."
        cleaned = _trim_model_text(text)
        # Kept inside the try block on purpose: if the safety check raises,
        # the model text is discarded and the template fallback is shown.
        assert_safe_text(cleaned)
        return (
            f"**Model status:** generated with `{model_id}`. Facts come only from evidence rows below; verify all site and professional items.\n\n"
            + cleaned
        )
    except Exception as exc:  # noqa: BLE001 - deliberate best-effort boundary
        return fallback + f"\n\n**Model status:** small-model generation failed ({type(exc).__name__}); template fallback shown."
def _template_brief(
    selection: SiteSelection,
    rows: list[EvidenceItem],
    warnings: list[str],
    project_type: str,
) -> str:
    """Build the deterministic Markdown brief that needs no model call.

    The text has three parts: a caption-draft header, up to six evidence
    bullets from ``_top_evidence`` and the verification bullets from
    ``_top_verification``. The joined result is passed to
    ``assert_safe_text`` before it is returned.
    """
    if selection.selection_type == "pin_radius":
        mode_note = "This is radius-based context analysis, not exact plot analysis."
    else:
        mode_note = "Treat boundary conclusions according to the uploaded/drawn source reliability."

    header = [
        "**Model status:** template fallback active. Set `ENABLE_SMALL_MODEL=1`, `HF_TOKEN`, and optionally `SMALL_MODEL_ID` to enable a <=4B Hugging Face model.",
        "",
        "**Studio caption draft**",
        "",
        f"- Site selection mode: `{selection.selection_type}`. {mode_note}",
        f"- Project type: {project_type or 'not specified'}; use this only to frame captions, not to invent design decisions.",
        "- Public-data and uploaded-file findings are evidence-backed where rows exist; missing layers stay as checklist items.",
        "",
        "**Evidence-backed points to use**",
        "",
    ]
    evidence_bullets = [f"- {point}" for point in _top_evidence(rows)]
    verification_heading = ["", "**Ask / verify before final sheet**", ""]
    verification_bullets = [f"- {point}" for point in _top_verification(rows, warnings)]

    brief = "\n".join(
        header + evidence_bullets + verification_heading + verification_bullets
    )
    assert_safe_text(brief)
    return brief
def _top_evidence(rows: list[EvidenceItem]) -> list[str]:
    """Return up to six formatted evidence bullets, in row order.

    Rows flagged by ``_is_missing_data_row`` and rows whose ``output_label``
    is not one of the accepted labels are skipped. When nothing qualifies a
    single placeholder sentence is returned instead.
    """
    accepted_labels = {"public_data", "computed", "cad_derived", "user_input", "site_visit_required"}
    picked: list[str] = []
    for row in rows:
        if len(picked) >= 6:
            break
        if _is_missing_data_row(row) or row.output_label not in accepted_labels:
            continue
        picked.append(f"{row.id}: {row.finding} Confidence: {row.confidence}.")
    if not picked:
        return ["No evidence rows were available; use the checklist and retry data sources."]
    return picked
def _is_missing_data_row(row: EvidenceItem) -> bool:
    """Tell whether a row describes a failed or unavailable data fetch.

    The row's ``finding`` and ``limitation`` are joined with a space,
    lower-cased and searched for any of a fixed set of failure phrases.
    """
    failure_markers = (
        "could not be retrieved",
        "unavailable in this run",
        "request failed",
        "retrieval failed",
    )
    combined = f"{row.finding} {row.limitation}".lower()
    for marker in failure_markers:
        if marker in combined:
            return True
    return False
def _top_verification(rows: list[EvidenceItem], warnings: list[str]) -> list[str]:
    """Collect the verification bullets for the brief.

    Up to six distinct, non-empty ``verification_needed`` values are taken
    from ``rows`` in order. ``warnings`` then top the list up to at most seven
    entries. Blank warnings and warnings that repeat an existing entry are
    skipped, matching how row entries are treated, so the limited remaining
    space is not spent on a duplicate or empty bullet.

    Args:
        rows: Evidence rows; only ``verification_needed`` is read.
        warnings: Additional warning strings.

    Returns:
        The collected items, or a single generic reminder when none exist.
    """
    items: list[str] = []
    for row in rows:
        if len(items) >= 6:
            break
        note = row.verification_needed
        if note and note not in items:
            items.append(note)
    for warning in warnings:
        if len(items) >= 7:
            break
        # Same filter as for rows: no empty bullets, no repeated bullets.
        if warning and warning not in items:
            items.append(warning)
    return items or ["Verify site conditions manually before design claims."]
def _prompt(
    selection: SiteSelection,
    rows: list[EvidenceItem],
    warnings: list[str],
    project_type: str,
) -> str:
    """Assemble the instruction prompt sent to the small model.

    Only the first 12 evidence rows and the first 5 warnings are included.
    Each row is rendered as one pipe-separated line; when there are no
    warnings a fixed placeholder sentence is used instead.
    """
    evidence_lines = []
    for row in rows[:12]:
        evidence_lines.append(
            f"{row.id} | {row.category} | {row.finding} | confidence={row.confidence} | limitation={row.limitation} | verify={row.verification_needed}"
        )
    evidence_block = "\n".join(evidence_lines)

    warning_block = "\n".join(warnings[:5])
    if not warning_block:
        warning_block = "No additional warnings."

    project_label = project_type or "not specified"

    # The trailing empty entry reproduces the prompt's final newline.
    prompt_lines = [
        "You are writing for an architecture student's site-analysis board.",
        "Use only the evidence rows. Do not invent site facts, culture, demographics, laws, foundations, or final design decisions.",
        "Do not say the boundary, soil, foundation, or design is exact, safe, or correct.",
        f"Selection mode: {selection.selection_type}",
        f"Project type: {project_label}",
        "Evidence:",
        evidence_block,
        "Warnings:",
        warning_block,
        "Write:",
        "1. three concise board captions;",
        "2. five verification questions for the site visit;",
        "3. one short uncertainty note.",
        "",
    ]
    return "\n".join(prompt_lines)
def _call_hf_inference(model_id: str, token: str, prompt: str) -> str:
    """POST ``prompt`` to the hosted Hugging Face inference endpoint.

    Args:
        model_id: Hub model identifier, interpolated into the endpoint URL.
        token: Bearer token sent in the ``Authorization`` header.
        prompt: Full prompt text sent as ``inputs``.

    Returns:
        The ``generated_text`` of the response, or ``""`` when the response
        does not contain generated text in a recognised shape. Returning an
        empty string lets the caller report "no usable text" instead of
        showing a stringified payload as if the model had written it.

    Raises:
        requests.HTTPError: From ``raise_for_status`` on a non-success status.
        RuntimeError: When the JSON payload reports an ``error`` instead of
            generated text.
        Other exceptions from ``requests.post`` / ``response.json`` propagate
        unchanged; the caller treats any exception as a failed generation.
    """
    url = f"https://api-inference.huggingface.co/models/{model_id}"
    response = requests.post(
        url,
        headers={"Authorization": f"Bearer {token}"},
        json={
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": 360,
                "temperature": 0.2,
                "return_full_text": False,
            },
        },
        timeout=45,
    )
    response.raise_for_status()
    payload = response.json()

    if isinstance(payload, list):
        first = payload[0] if payload else None
        if isinstance(first, dict):
            return str(first.get("generated_text") or "")
        return ""

    if isinstance(payload, dict):
        generated = payload.get("generated_text")
        if generated:
            return str(generated)
        error = payload.get("error")
        if error:
            # An API error message must never be presented as model output.
            raise RuntimeError(f"inference API returned an error for {model_id}: {error}")
        return ""

    # Any other JSON value (string, number, null) carries no generated text.
    return ""
def _trim_model_text(text: str) -> str:
    """Strip ``text`` and cap it at 3500 characters.

    Over-long text is cut at 3500 characters and then shortened to the last
    line break inside that window, when there is one, so the result does not
    end in a partial line.
    """
    stripped = text.strip()
    if len(stripped) <= 3500:
        return stripped

    clipped = stripped[:3500]
    last_break = clipped.rfind("\n")
    if last_break != -1:
        clipped = clipped[:last_break]
    return clipped.strip()