| | import json |
| | from pathlib import Path |
| | from openai import OpenAI |
| | from datasets import load_dataset |
| |
|
| | |
# --- Inference endpoint configuration ---
API_BASE = "http://172.16.34.22:3090/v1"  # OpenAI-compatible server (e.g. vLLM) on the local network
MODEL_PATH = "sc"  # model identifier as registered on the serving endpoint
# Fine-tuning dataset with subclaim-support annotations (JSON).
DATASET_FILE = Path("/home/mshahidul/readctrl/data/finetuning_data/finetune_dataset_subclaim_support_v2.json")
# Which text field of each dataset item is fed to the model
# ("hard_text" variant; other variants presumably exist in the data — confirm against the dataset schema).
TEXT_VARIANT = "hard_text"

# Module-level client; the local server does not check the API key, so "EMPTY" suffices.
client = OpenAI(api_key="EMPTY", base_url=API_BASE)
| |
|
# Llama-3-style chat template rendered manually because the script calls the
# raw /completions endpoint rather than /chat/completions; {user_prompt} is
# the single substitution slot filled by render_chat_prompt(). The trailing
# assistant header leaves the model positioned to generate its reply.
CHAT_TEMPLATE = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    "Cutting Knowledge Date: December 2023\n"
    "Today Date: 26 July 2024\n\n"
    "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
    "{user_prompt}"
    "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
| |
|
| |
|
def render_chat_prompt(user_prompt: str) -> str:
    """Embed *user_prompt* into the raw chat template for the completion API."""
    rendered = CHAT_TEMPLATE.format(user_prompt=user_prompt)
    return rendered
| |
|
def build_user_prompt(text: str, subclaims: list[str]) -> str:
    """Compose the user-turn prompt: task instructions, the medical passage,
    and a 1-based numbered list of subclaims to label."""
    numbered_lines = [f"{n}. {claim}" for n, claim in enumerate(subclaims, start=1)]
    numbered_subclaims = "\n".join(numbered_lines)
    parts = [
        "You are a medical evidence checker.\n",
        "Given a medical passage and a list of subclaims, return labels for each ",
        "subclaim in the same order.\n\n",
        "Allowed labels: supported, not_supported.\n",
        "Output format: a JSON array of strings only.\n\n",
        f"Medical text:\n{text}\n\n",
        f"Subclaims:\n{numbered_subclaims}",
    ]
    return "".join(parts)
| |
|
def main():
    """Run subclaim-support inference over the held-out test split.

    Loads the JSON dataset, reproduces the 90/10 train/test split (same
    seed/shuffle as presumably used at fine-tuning time — confirm against the
    training script), queries the completion endpoint once per item with
    temperature 0, and writes predictions alongside gold labels to
    ``inference_results.json``.
    """
    raw_dataset = load_dataset("json", data_files=str(DATASET_FILE), split="train")

    # Fixed seed so the test split is reproducible across runs.
    splits = raw_dataset.train_test_split(test_size=0.1, seed=3407, shuffle=True)
    test_split = splits["test"]

    print(f"Running inference on {len(test_split)} samples...")

    results = []
    for row in test_split:
        for item in row.get("items", []):
            text = item.get(TEXT_VARIANT, "").strip()
            # Single pass over the subclaim entries: each carries both the
            # claim text and its gold label (was fetched twice before).
            entries = item.get("subclaims", [])
            subclaims = [s["subclaim"] for s in entries]
            gold_labels = [s["label"] for s in entries]

            # Nothing to ask the model about — skip empty items.
            if not text or not subclaims:
                continue

            prompt = render_chat_prompt(build_user_prompt(text, subclaims))
            # temperature=0 for (near-)deterministic labeling.
            response = client.completions.create(
                model=MODEL_PATH,
                prompt=prompt,
                temperature=0,
                max_tokens=256,
            )

            pred_text = response.choices[0].text.strip()

            print("--- Sample ---")  # plain string: no placeholders (was a needless f-string)
            print(f"Pred: {pred_text}")
            print(f"Gold: {gold_labels}")

            results.append({
                "predicted": pred_text,
                "gold": gold_labels,
            })

    # ensure_ascii=False keeps any non-ASCII medical text readable on disk.
    with open("inference_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4, ensure_ascii=False)
| |
|
# Script entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":
    main()