"""Test harness for the Prisma system prompt.
Sends contrasting user messages to Llama 3.3 70B via HF Inference API,
parses the dual-role JSON output, and prints scores side-by-side to
inspect (a) JSON parseability, (b) score variance, (c) response cleanliness.
Usage:
python scripts/test_prompt.py
Requires HF_TOKEN in a .env file at the repo root.
"""
from __future__ import annotations
import json
import os
import sys
from pathlib import Path
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
# Make the `src` package importable when this script is run directly.
# REPO_ROOT is the parent of the directory that contains this file; it is
# placed first on sys.path so the two project imports below resolve. The
# imports must therefore stay after the sys.path change (hence the E402
# suppressions).
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))
from src.config import DEFAULT_ATTRIBUTES # noqa: E402
from src.prompt import build_system_prompt # noqa: E402
# Model id passed as `model=` to every chat_completion call.
MODEL_ID = "meta-llama/Llama-3.3-70B-Instruct"
# Alternative model id, currently disabled.
#MODEL_ID = "Qwen/Qwen2.5-72B-Instruct"
# Passed as `max_tokens=` to every chat_completion call.
MAX_TOKENS = 600
# Passed as `temperature=` to every chat_completion call.
TEMPERATURE = 0.7
# Each entry is a (label, user_message) pair. The messages are chosen to
# differ along plausibly perceptible dimensions: precision, formality,
# politeness, demandingness, and expertise.
TEST_MESSAGES: list[tuple[str, str]] = [
    (
        "precise",
        "I'll arrive at 7:03 PM sharp, having reviewed all 47 pages of the "
        "report beforehand.",
    ),
    (
        "vague",
        "yeah i guess i'll be there at like seven or whenever idk",
    ),
    (
        "formal",
        "Good afternoon. I would be most grateful if you could provide a "
        "brief overview of the topic.",
    ),
    (
        "casual",
        "hey can u explain this thing real quick lol",
    ),
    (
        "polite",
        "Hi! When you have a moment, could you please help me understand "
        "how this works? Thanks so much!",
    ),
    (
        "demanding",
        "Tell me how this works. Now. Don't waste my time.",
    ),
    (
        "expert",
        "I'm curious about the trade-offs between in-context learning and "
        "fine-tuning for low-resource domain adaptation.",
    ),
    (
        "confused",
        "umm so like the thing... how does it work? i dont get it",
    ),
]
def query_model(
    client: InferenceClient,
    system_prompt: str,
    user_message: str,
) -> tuple[str | None, dict | None]:
    """Send a single-turn chat query and parse the reply as a JSON object.

    Args:
        client: Inference client exposing ``chat_completion``.
        system_prompt: Content of the ``system`` message.
        user_message: Content of the ``user`` message.

    Returns:
        ``(raw_output, parsed_json)``. ``raw_output`` is ``None`` when the
        inference call fails or the reply carries no text. ``parsed_json``
        is ``None`` whenever the reply is not a parseable JSON *object*.

    Failures are reported on stdout instead of being raised, so one bad
    message does not abort the rest of the run.
    """
    try:
        completion = client.chat_completion(
            model=MODEL_ID,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_message},
            ],
            max_tokens=MAX_TOKENS,
            temperature=TEMPERATURE,
            response_format={"type": "json_object"},
        )
        raw = completion.choices[0].message.content
    except Exception as exc:  # deliberate catch-all at the API boundary
        print(f" [error] inference call failed: {exc}")
        return None, None

    # A reply without text content would make json.loads raise TypeError,
    # which the JSONDecodeError handler below does not cover.
    if raw is None:
        print(" [warn] model returned no content")
        return None, None

    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError as exc:
        print(f" [warn] JSON parse failed: {exc}")
        print(f" [raw] {raw[:400]!r}")
        return raw, None

    # Valid JSON can still be a list, string, or number; callers rely on
    # dict methods, so anything else counts as a parse failure.
    if not isinstance(parsed, dict):
        print(f" [warn] expected a JSON object, got {type(parsed).__name__}")
        print(f" [raw] {raw[:400]!r}")
        return raw, None

    return raw, parsed
def print_score_table(results: dict[str, dict]) -> None:
    """Print a side-by-side comparison of evaluation scores across labels.

    Args:
        results: Maps each message label to its parsed model output. Scores
            are read from that output's ``"evaluation"`` entry, one row per
            name in ``DEFAULT_ATTRIBUTES`` and one column per label.

    Prints nothing when ``results`` is empty. A missing attribute, or an
    ``"evaluation"`` entry that is not a dict, is shown as an em dash.
    """
    labels = list(results)
    if not labels:
        # max() over an empty sequence would raise ValueError.
        return

    col_width = max(8, max(len(label) for label in labels) + 2)
    header = f"{'attribute':<16}" + "".join(
        f"{label:>{col_width}}" for label in labels
    )
    print()
    print(header)
    print("-" * len(header))
    for attr in DEFAULT_ATTRIBUTES:
        cells = []
        for label in labels:
            evaluation = results[label].get("evaluation")
            # The model may emit null or a non-object here.
            if not isinstance(evaluation, dict):
                evaluation = {}
            score = evaluation.get(attr, "—")
            # !s: None, lists and dicts reject alignment format specs, so
            # convert to text before padding.
            cells.append(f"{score!s:>{col_width}}")
        print(f"{attr:<16}" + "".join(cells))
    print()
def main() -> None:
    """Run every test message through the model and summarise the results.

    Reads ``HF_TOKEN`` from the environment after loading ``.env`` and exits
    with status 1 when it is missing. Otherwise queries the model once per
    entry in ``TEST_MESSAGES``, prints a preview of each parsed reply, then
    prints the score table (if anything parsed) and the parseable count.
    """
    load_dotenv()
    token = os.getenv("HF_TOKEN")
    if not token:
        print("ERROR: HF_TOKEN not found. Check your .env file.")
        sys.exit(1)

    client = InferenceClient(token=token)
    system_prompt = build_system_prompt()

    print("=" * 72)
    print("PRISMA prompt test harness")
    print("=" * 72)
    print(f"Model: {MODEL_ID}")
    print(f"Attributes: {', '.join(DEFAULT_ATTRIBUTES)}")
    print(f"Messages: {len(TEST_MESSAGES)}")
    print()

    parsed_results: dict[str, dict] = {}
    for label, message in TEST_MESSAGES:
        print(f"[{label}] {message}")
        _raw, parsed = query_model(client, system_prompt, message)
        # Only a JSON object has fields to read; None (failure) and any
        # other JSON value are skipped and not counted as parseable.
        if not isinstance(parsed, dict):
            print()
            continue
        # The model controls the type of this field; coerce to text so the
        # slice and len() below cannot fail on null, numbers, or objects.
        response = str(parsed.get("response", "(missing 'response' field)"))
        preview = response[:120] + ("..." if len(response) > 120 else "")
        print(f" response: {preview}")
        print(f" evaluation: {parsed.get('evaluation', {})}")
        parsed_results[label] = parsed
        print()

    if parsed_results:
        print_score_table(parsed_results)
    print(f"Parseable: {len(parsed_results)}/{len(TEST_MESSAGES)}")
# Run the harness only when executed as a script, not when imported.
if __name__ == "__main__":
    main()