# readctrl/code/text_classifier/test_saved_dspy_vllm_gen_text_only.py
# Uploaded by shahidul034 - commit 1db7196 ("Add files using upload-large-folder tool").
import argparse
import json
import os
import traceback
import urllib.error
import urllib.request
import dspy
from dspy.evaluate import Evaluate
# Root of the text-classifier project; every default file path below lives under it.
_TEXT_CLASSIFIER_ROOT = "/home/mshahidul/readctrl/code/text_classifier"

# OpenAI-compatible (vLLM) endpoint used when neither --api-base nor the
# VLLM_API_BASE environment variable is supplied.
DEFAULT_API_BASE = "http://172.16.34.22:8040/v1"

# Saved DSPy program to evaluate (--model-path).
DEFAULT_MODEL_PATH = (
    _TEXT_CLASSIFIER_ROOT
    + "/dspy_model/vllm-Meta-Llama-3.1-8B-Instruct_teacher-gpt5_v1/model.json"
)

# Labelled evaluation data (--test-path).
DEFAULT_TEST_PATH = _TEXT_CLASSIFIER_ROOT + "/data/verified_combined_0-80_clean200.json"

# Destination of the JSON accuracy summary (--output-path).
DEFAULT_OUTPUT_PATH = (
    _TEXT_CLASSIFIER_ROOT
    + "/accuracy/vllm-llama-3.1-8b-awq-int4_teacher-gpt5_v1_clean200_eval.json"
)
# DSPy signature for the task: one rewritten text in, one literacy label out.
# NOTE(review): no class docstring on purpose - DSPy uses a Signature's
# docstring as the task instructions sent to the LM, so adding one would likely
# change the prompt. Confirm against the installed DSPy version before adding one.
class HealthLiteracySignature(dspy.Signature):
    # Input field: the text to classify. The desc string is presumably rendered
    # into the LM prompt, so treat it as runtime behavior, not documentation.
    generated_text = dspy.InputField(
        desc="A version of the source text rewritten for a specific audience."
    )
    # Output field: the predicted label. health_literacy_metric checks whether
    # the gold label appears (case-insensitively) inside this value.
    # NOTE(review): the three label names in desc are assumed to match the gold
    # labels in the test data - verify against the dataset.
    literacy_label = dspy.OutputField(
        desc=(
            "Classification: low_health_literacy (simple words, no jargon), "
            "intermediate_health_literacy (moderate technicality), or "
            "proficient_health_literacy (highly technical/original level)."
        )
    )
class HealthLiteracyClassifier(dspy.Module):
    """Chain-of-thought program over ``HealthLiteracySignature``.

    ``load_compiled_classifier`` instantiates this class and loads saved state
    into it when ``dspy.load`` is unavailable or fails.
    """

    def __init__(self):
        super().__init__()
        signature = HealthLiteracySignature
        # Keep the attribute name ``classifier``: saved program state is
        # presumably keyed by it, so renaming would break loading.
        self.classifier = dspy.ChainOfThought(signature)

    def forward(self, generated_text):
        """Run the chain-of-thought predictor on ``generated_text`` and return its prediction."""
        prediction = self.classifier(generated_text=generated_text)
        return prediction
def parse_args(argv=None):
    """Parse the command-line options for the evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]`` exactly as before; passing a list
            lets callers (tests, notebooks) parse arguments programmatically.

    Returns:
        argparse.Namespace with ``model_path``, ``test_path``, ``api_base``,
        ``num_threads``, ``output_path`` and ``provide_traceback``.

    Exits (via ``parser.error``) with a usage message when
    ``--num-threads`` is smaller than 1.
    """
    parser = argparse.ArgumentParser(
        description="Load a saved DSPy model and evaluate on test set."
    )
    parser.add_argument(
        "--model-path",
        default=DEFAULT_MODEL_PATH,
        help="Path of the saved DSPy program to evaluate.",
    )
    parser.add_argument(
        "--test-path",
        default=DEFAULT_TEST_PATH,
        help="Labelled test set (.json list or .jsonl).",
    )
    parser.add_argument(
        "--api-base",
        # The VLLM_API_BASE environment variable overrides the built-in default;
        # an explicit --api-base flag overrides both.
        default=os.environ.get("VLLM_API_BASE", DEFAULT_API_BASE),
        help="Base URL of the OpenAI-compatible (vLLM) server.",
    )
    parser.add_argument(
        "--num-threads",
        type=int,
        default=1,
        help="Number of parallel evaluation threads (>= 1).",
    )
    parser.add_argument(
        "--output-path",
        default=DEFAULT_OUTPUT_PATH,
        help="Where to write the JSON accuracy summary.",
    )
    parser.add_argument(
        "--provide-traceback",
        action="store_true",
        help="Print full traceback if runtime error happens.",
    )
    args = parser.parse_args(argv)
    # Reject impossible thread counts here with a clear usage error instead of
    # letting the evaluator fail later with a less obvious message.
    if args.num_threads < 1:
        parser.error("--num-threads must be >= 1")
    return args
def check_api_base(api_base, timeout=5):
    """Fail fast if the OpenAI-compatible server at ``api_base`` is not usable.

    Sends ``GET {api_base}/models`` and inspects the outcome.

    Args:
        api_base: Base URL such as ``http://host:port/v1``; a trailing slash
            is ignored.
        timeout: Seconds to wait for the probe (default 5, as before).

    Raises:
        RuntimeError: The server answered, but with an HTTP error status.
        ConnectionError: The server could not be reached (refused connection,
            DNS failure, timeout, dropped connection).
    """
    models_url = api_base.rstrip("/") + "/models"
    req = urllib.request.Request(models_url, method="GET")
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            status = resp.status
    except urllib.error.HTTPError as exc:
        # HTTPError subclasses URLError, so it must be handled first; otherwise a
        # 4xx/5xx reply would be misreported as an unreachable server.
        raise RuntimeError(
            f"Endpoint reachable but unhealthy: {models_url} (status={exc.code})"
        ) from exc
    except OSError as exc:
        # URLError (refused/DNS) is an OSError; timeouts or resets while reading
        # the response can also surface as plain OSError subclasses.
        raise ConnectionError(
            "Cannot reach OpenAI-compatible endpoint. "
            f"api_base={api_base}. "
            "Start your vLLM server or pass correct --api-base."
        ) from exc
    # The default opener raises HTTPError for error statuses; this keeps the
    # explicit guard for any opener that returns them instead.
    if status >= 400:
        raise RuntimeError(
            f"Endpoint reachable but unhealthy: {models_url} (status={status})"
        )
def _record_to_example(record):
    """Convert one test record into a DSPy Example, or return None to skip it.

    Accepts either ``generated_text``/``literacy_label`` or the alternative
    ``diff_label_texts``/``label`` keys. Records with a missing or empty text
    or label are skipped.
    """
    text = record.get("generated_text", record.get("diff_label_texts"))
    label = record.get("literacy_label", record.get("label"))
    if not text or not label:
        return None
    return dspy.Example(
        generated_text=text,
        literacy_label=label,
    ).with_inputs("generated_text")


def load_testset(path):
    """Load labelled examples from a ``.jsonl`` file or a JSON list file.

    Both formats go through the same record conversion, so they accept the
    same key names and skip the same incomplete records. (Previously the JSONL
    branch required the primary keys and raised ``KeyError`` otherwise.)

    Args:
        path: File path; a ``.jsonl`` suffix selects line-delimited JSON,
            anything else is parsed as one JSON document holding a list.

    Returns:
        list of ``dspy.Example`` with ``generated_text`` marked as the input.
    """
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    with open(path, "r", encoding="utf-8") as f:
        if path.endswith(".jsonl"):
            records = [json.loads(line) for line in f if line.strip()]
        else:
            records = json.load(f)
    examples = []
    for record in records:
        example = _record_to_example(record)
        if example is not None:
            examples.append(example)
    return examples
def health_literacy_metric(gold, pred, trace=None):
    """Return True when the predicted label contains the gold label.

    The comparison is case-insensitive and ignores surrounding whitespace.
    Matching is by substring, so a prediction such as
    "Classification: low_health_literacy" still counts as correct.

    Args:
        gold: Reference example exposing ``literacy_label``.
        pred: Prediction exposing ``literacy_label``; a falsy value or an
            object without that attribute counts as wrong.
        trace: Unused; kept to match the DSPy metric call signature.

    Returns:
        bool. An empty or whitespace-only gold label never matches; with a
        plain substring test, "" would be found in every prediction and
        silently inflate accuracy.
    """
    if not pred or not hasattr(pred, "literacy_label"):
        return False
    gold_label = str(gold.literacy_label).strip().lower()
    if not gold_label:
        return False
    pred_label = str(pred.literacy_label).strip().lower()
    return gold_label in pred_label
def load_compiled_classifier(path):
    """Load the compiled classifier saved at ``path``.

    First tries ``dspy.load`` when the installed DSPy provides it. On any
    failure a warning is printed, and the state is loaded into a fresh
    ``HealthLiteracyClassifier`` via its ``load`` method instead.

    Raises:
        RuntimeError: The fallback ``load`` also failed (original error chained).
    """
    if hasattr(dspy, "load"):
        try:
            return dspy.load(path)
        except Exception as exc:
            failure_name = type(exc).__name__
            print(
                f"[warning] dspy.load failed ({failure_name}); "
                "trying module.load(...)"
            )

    program = HealthLiteracyClassifier()
    try:
        program.load(path)
    except Exception as exc:
        raise RuntimeError(f"Failed to load compiled model from {path}") from exc
    else:
        return program
def _write_summary(path, data):
    """Write ``data`` as indented JSON to ``path``, creating parent directories when needed."""
    directory = os.path.dirname(path)
    # A bare filename such as "eval.json" has an empty dirname, and
    # os.makedirs("") raises FileNotFoundError - only create real directories.
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)


def main():
    """Evaluate a saved DSPy health-literacy classifier on a labelled test set.

    The steps are:
    1. Check that the input paths exist.
    2. Probe the vLLM endpoint.
    3. Configure DSPy.
    4. Load the test set and the compiled program.
    5. Run ``Evaluate``.
    6. Write a JSON summary to ``--output-path`` and print it.

    Raises:
        FileNotFoundError: The model or test file does not exist.
        ValueError: The test file yielded no usable examples.

    Any exception from the pipeline is printed as ``[error] ...`` (with a
    traceback when ``--provide-traceback`` is set) and then re-raised.
    """
    args = parse_args()
    if not os.path.exists(args.model_path):
        raise FileNotFoundError(f"Model file not found: {args.model_path}")
    if not os.path.exists(args.test_path):
        raise FileNotFoundError(f"Test file not found: {args.test_path}")
    try:
        check_api_base(args.api_base)
        # NOTE(review): "openai/" selects the OpenAI-compatible client; "dspy" is
        # presumably the served model name configured on the vLLM server - verify.
        lm = dspy.LM(
            model="openai/dspy",
            api_base=args.api_base,
            api_key="EMPTY",
            temperature=0.0,
        )
        dspy.configure(lm=lm)
        testset = load_testset(args.test_path)
        # An empty devset would produce a meaningless (or failing) score.
        if not testset:
            raise ValueError(f"No usable examples found in {args.test_path}")
        compiled_classifier = load_compiled_classifier(args.model_path)
        evaluator = Evaluate(
            devset=testset,
            metric=health_literacy_metric,
            num_threads=args.num_threads,
            display_progress=True,
        )
        evaluation_result = evaluator(compiled_classifier)
        # Depending on the DSPy version the evaluator returns either a result
        # object exposing ``.score`` or the score itself as a number.
        # NOTE(review): DSPy's score is typically a percentage (0-100), not a
        # fraction - confirm before treating accuracy_score as a ratio.
        accuracy_score = (
            float(evaluation_result.score)
            if hasattr(evaluation_result, "score")
            else float(evaluation_result)
        )
        output_data = {
            "model_path": args.model_path,
            "test_path": args.test_path,
            "accuracy_score": accuracy_score,
            "num_results": len(getattr(evaluation_result, "results", []) or []),
        }
        _write_summary(args.output_path, output_data)
        print(evaluation_result)
        print(json.dumps(output_data, indent=2))
    except Exception as exc:
        print(f"[error] {type(exc).__name__}: {exc}")
        if args.provide_traceback:
            traceback.print_exc()
        raise
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()