File size: 3,891 Bytes
030876e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import json
import os
import re

import torch
from datasets import Dataset
from unsloth import FastLanguageModel

# Pin GPU enumeration to PCI bus order and restrict this process to GPU 0,
# so "cuda" below always resolves to the same physical device.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Fine-tuning dataset (list of {"conversations": [...]} records) used to
# re-derive the held-out test split with the same seed as training.
DATA_PATH = "/home/mshahidul/readctrl/data/finetuning_data/finetune_dataset_subclaim_support_v2_sft_prompt.json"
# LoRA-merged/adapter checkpoint of the subclaim-support verifier.
MODEL_PATH = "/home/mshahidul/readctrl_model/qwen3-8B_subclaims-verifier_lora_nonreasoning"
# Per-example predictions (JSONL) and aggregate accuracy summary (JSON).
OUTPUT_PATH = "/home/mshahidul/readctrl/results/qwen3-8B_subclaims_verifier_test_predictions.jsonl"
SUMMARY_PATH = "/home/mshahidul/readctrl/results/qwen3-8B_subclaims_verifier_test_summary.json"


def normalize_label(text: str) -> str:
    """Normalize raw model/gold text to 'supported', 'not_supported', or 'unknown'.

    Matching is case-insensitive. Negative forms ('not_supported',
    'not supported') are checked before the positive form so that a
    negative answer is never misread as 'supported'.
    """
    if text is None:
        return "unknown"
    lowered = text.strip().lower().replace("\n", " ").strip()
    # Negative labels take priority: 'supported' is a substring of both.
    for negative in ("not_supported", "not supported"):
        if negative in lowered:
            return "not_supported"
    # Exact label as the leading token (trailing punctuation tolerated).
    leading = re.split(r"\s+", lowered)[0].strip(".,:;")
    if leading in {"supported", "not_supported"}:
        return leading
    # Fall back to a substring match anywhere in the text.
    return "supported" if "supported" in lowered else "unknown"


def get_turn(conversations, role: str) -> str:
    """Return the content of the first turn spoken by *role*, or '' if absent.

    Each turn is expected to be a dict with 'from' (speaker) and
    'content' (text) keys.
    """
    matches = (t.get("content", "") for t in conversations if t.get("from") == role)
    return next(matches, "")


def main() -> None:
    """Evaluate the fine-tuned verifier on the held-out test split.

    Re-derives the 80/20 test split from DATA_PATH (seed 3407 — must match
    the seed used at training time for the split to be truly held-out),
    runs greedy generation per example, normalizes predictions and gold
    labels, writes per-example records to OUTPUT_PATH (JSONL) and an
    accuracy summary to SUMMARY_PATH (JSON).

    Raises:
        RuntimeError: if no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. Please run on a GPU.")

    with open(DATA_PATH, "r") as f:
        data = json.load(f)

    # Same seed/ratio as training presumably reproduces the exact split —
    # TODO confirm against the training script.
    dataset = Dataset.from_list(data)
    split_dataset = dataset.train_test_split(test_size=0.2, seed=3407, shuffle=True)
    test_data = split_dataset["test"]

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_PATH,
        max_seq_length=8192,
        load_in_4bit=False,
    )
    # Switch unsloth model into optimized inference mode (disables training paths).
    FastLanguageModel.for_inference(model)

    total = len(test_data)
    correct = 0

    with open(OUTPUT_PATH, "w") as out_f:
        for idx, item in enumerate(test_data):
            # NOTE(review): assumes turns use {"from": ..., "content": ...};
            # if the dataset uses the ShareGPT "value" key instead of
            # "content", get_turn would return "" — verify the data schema.
            user_text = get_turn(item["conversations"], "user")
            gold_text = get_turn(item["conversations"], "assistant")
            gold_label = normalize_label(gold_text)

            messages = [{"role": "user", "content": user_text}]
            input_text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            inputs = tokenizer([input_text], return_tensors="pt").to("cuda")

            with torch.no_grad():
                generated = model.generate(
                    **inputs,
                    # Labels are short ('supported'/'not_supported'); 20 new
                    # tokens is ample headroom.
                    max_new_tokens=20,
                    do_sample=False,  # greedy decoding for deterministic labels
                    use_cache=True,
                    pad_token_id=tokenizer.eos_token_id,
                )

            # Decode only the newly generated tail, excluding the prompt tokens.
            gen_text = tokenizer.decode(
                generated[0][inputs["input_ids"].shape[-1]:],
                skip_special_tokens=True,
            )
            pred_label = normalize_label(gen_text)
            is_correct = pred_label == gold_label
            correct += int(is_correct)

            record = {
                "index": idx,
                "label": gold_label,
                "prediction": pred_label,
                "correct": is_correct,
                "raw_output": gen_text.strip(),
            }
            out_f.write(json.dumps(record, ensure_ascii=False) + "\n")

            if (idx + 1) % 100 == 0:
                print(f"Processed {idx + 1}/{total}")

    # Guard against an empty test split to avoid ZeroDivisionError.
    accuracy = correct / total if total else 0.0
    summary = {
        "total": total,
        "correct": correct,
        "accuracy": accuracy,
    }
    with open(SUMMARY_PATH, "w") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    print(f"Accuracy: {accuracy:.4f}")
    print(f"Saved predictions: {OUTPUT_PATH}")
    print(f"Saved summary: {SUMMARY_PATH}")


# Run evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()