"""Test harness for the Prisma system prompt. Sends contrasting user messages to Llama 3.3 70B via HF Inference API, parses the dual-role JSON output, and prints scores side-by-side to inspect (a) JSON parseability, (b) score variance, (c) response cleanliness. Usage: python scripts/test_prompt.py Requires HF_TOKEN in a .env file at the repo root. """ from __future__ import annotations import json import os import sys from pathlib import Path from dotenv import load_dotenv from huggingface_hub import InferenceClient # Make the src package importable when running this script directly REPO_ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(REPO_ROOT)) from src.config import DEFAULT_ATTRIBUTES # noqa: E402 from src.prompt import build_system_prompt # noqa: E402 MODEL_ID = "meta-llama/Llama-3.3-70B-Instruct" #MODEL_ID = "Qwen/Qwen2.5-72B-Instruct" MAX_TOKENS = 600 TEMPERATURE = 0.7 # (label, user_message) — pairs chosen to vary along plausibly-perceptible # dimensions: precision, formality, politeness, demandingness, expertise. TEST_MESSAGES: list[tuple[str, str]] = [ ("precise", "I'll arrive at 7:03 PM sharp, having reviewed all 47 pages of the " "report beforehand."), ("vague", "yeah i guess i'll be there at like seven or whenever idk"), ("formal", "Good afternoon. I would be most grateful if you could provide a " "brief overview of the topic."), ("casual", "hey can u explain this thing real quick lol"), ("polite", "Hi! When you have a moment, could you please help me understand " "how this works? Thanks so much!"), ("demanding", "Tell me how this works. Now. Don't waste my time."), ("expert", "I'm curious about the trade-offs between in-context learning and " "fine-tuning for low-resource domain adaptation."), ("confused", "umm so like the thing... how does it work? i dont get it"), ] def query_model( client: InferenceClient, system_prompt: str, user_message: str, ) -> tuple[str | None, dict | None]: """Send a single-turn query and return (raw_output, parsed_json). parsed_json is None if JSON parsing fails. """ try: completion = client.chat_completion( model=MODEL_ID, messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_message}, ], max_tokens=MAX_TOKENS, temperature=TEMPERATURE, response_format={"type": "json_object"}, ) raw = completion.choices[0].message.content except Exception as exc: print(f" [error] inference call failed: {exc}") return None, None try: return raw, json.loads(raw) except json.JSONDecodeError as exc: print(f" [warn] JSON parse failed: {exc}") print(f" [raw] {raw[:400]!r}") return raw, None def print_score_table(results: dict[str, dict]) -> None: """Print a side-by-side comparison of evaluation scores across labels.""" labels = list(results.keys()) col_width = max(8, max(len(label) for label in labels) + 2) header = f"{'attribute':<16}" + "".join(f"{l:>{col_width}}" for l in labels) print() print(header) print("-" * len(header)) for attr in DEFAULT_ATTRIBUTES: row = f"{attr:<16}" for label in labels: evaluation = results[label].get("evaluation", {}) score = evaluation.get(attr, "—") row += f"{score:>{col_width}}" print(row) print() def main() -> None: load_dotenv() token = os.getenv("HF_TOKEN") if not token: print("ERROR: HF_TOKEN not found. Check your .env file.") sys.exit(1) client = InferenceClient(token=token) system_prompt = build_system_prompt() print("=" * 72) print("PRISMA prompt test harness") print("=" * 72) print(f"Model: {MODEL_ID}") print(f"Attributes: {', '.join(DEFAULT_ATTRIBUTES)}") print(f"Messages: {len(TEST_MESSAGES)}") print() parsed_results: dict[str, dict] = {} for label, message in TEST_MESSAGES: print(f"[{label}] {message}") _raw, parsed = query_model(client, system_prompt, message) if parsed is None: print() continue response = parsed.get("response", "(missing 'response' field)") preview = response[:120] + ("..." if len(response) > 120 else "") print(f" response: {preview}") print(f" evaluation: {parsed.get('evaluation', {})}") parsed_results[label] = parsed print() if parsed_results: print_score_table(parsed_results) print(f"Parseable: {len(parsed_results)}/{len(TEST_MESSAGES)}") if __name__ == "__main__": main()