File size: 3,149 Bytes
c7a6fe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import json
from pathlib import Path
from openai import OpenAI
from datasets import load_dataset

# Configuration
# Base URL of the OpenAI-compatible inference server (e.g. vLLM) on the LAN.
API_BASE = "http://172.16.34.22:3090/v1"
# Model identifier as registered on the server; "sc" is presumably the served
# model alias for the fine-tuned subclaim-support checkpoint — TODO confirm.
MODEL_PATH = "sc"
# Fine-tuning dataset in JSON form; the test split is re-derived from it below.
DATASET_FILE = Path("/home/mshahidul/readctrl/data/finetuning_data/finetune_dataset_subclaim_support_v2.json")
# Which text field of each item to feed the model ("hard_text" variant).
TEXT_VARIANT = "hard_text"

# 1. Initialize OpenAI Client
# api_key="EMPTY" is the conventional placeholder for local servers that do
# not enforce authentication.
client = OpenAI(api_key="EMPTY", base_url=API_BASE)

# Raw Llama-3 chat template rendered client-side; the server is queried via
# the plain /completions endpoint, so the special tokens must be spelled out
# here. {user_prompt} is the only substitution slot.
CHAT_TEMPLATE = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    "Cutting Knowledge Date: December 2023\n"
    "Today Date: 26 July 2024\n\n"
    "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
    "{user_prompt}"
    "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)


def render_chat_prompt(user_prompt: str) -> str:
    """Embed *user_prompt* into the raw Llama-3 chat template string."""
    rendered = CHAT_TEMPLATE.format(user_prompt=user_prompt)
    return rendered

def build_user_prompt(text: str, subclaims: list[str]) -> str:
    """Compose the evidence-checking prompt: instructions, passage, numbered subclaims.

    The model is asked to emit a JSON array of labels, one per subclaim,
    in the same order the subclaims are listed.
    """
    listed = "\n".join(f"{rank}. {claim}" for rank, claim in enumerate(subclaims, start=1))
    instructions = (
        "You are a medical evidence checker.\n"
        "Given a medical passage and a list of subclaims, return labels for each "
        "subclaim in the same order.\n\n"
        "Allowed labels: supported, not_supported.\n"
        "Output format: a JSON array of strings only."
    )
    return f"{instructions}\n\nMedical text:\n{text}\n\nSubclaims:\n{listed}"

def main():
    """Run subclaim-support inference over the held-out test split and save predictions.

    Re-creates the 90/10 train/test split with the same seed used at
    fine-tuning time, queries the local vLLM server once per (text, subclaims)
    item, and writes predicted/gold label pairs to inference_results.json.
    """
    # 2. Load the original dataset
    raw_dataset = load_dataset("json", data_files=str(DATASET_FILE), split="train")

    # 3. Re-create the test split (same seed/ratio as during fine-tuning)
    splits = raw_dataset.train_test_split(test_size=0.1, seed=3407, shuffle=True)
    test_split = splits["test"]

    print(f"Running inference on {len(test_split)} samples...")

    results = []
    for row in test_split:
        for item in row.get("items", []):
            text = item.get(TEXT_VARIANT, "").strip()
            # Single pass over the subclaim records (was two identical lookups).
            subclaim_records = item.get("subclaims", [])
            subclaims = [s["subclaim"] for s in subclaim_records]
            gold_labels = [s["label"] for s in subclaim_records]

            # Skip items with nothing to verify.
            if not text or not subclaims:
                continue

            # 4. Render Llama chat template locally and request inference from vLLM.
            prompt = render_chat_prompt(build_user_prompt(text, subclaims))
            response = client.completions.create(
                model=MODEL_PATH,
                prompt=prompt,
                temperature=0,  # Keep it deterministic
                max_tokens=256,
            )

            pred_text = response.choices[0].text.strip()

            print("--- Sample ---")  # plain string: no placeholders to format
            print(f"Pred: {pred_text}")
            print(f"Gold: {gold_labels}")

            results.append({
                "predicted": pred_text,
                "gold": gold_labels,
            })

    # Save results; explicit UTF-8 and ensure_ascii=False keep any non-ASCII
    # medical text human-readable in the output file.
    with open("inference_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4, ensure_ascii=False)

if __name__ == "__main__":
    main()