"""Run subclaim-support inference against a local vLLM server and save results.

Loads the fine-tuning dataset, re-creates the held-out test split with the
same seed/ratio used during training, renders a Llama-3 chat prompt for each
(passage, subclaims) item, queries the raw /completions endpoint, and writes
predicted vs. gold labels to a JSON file.
"""

import json
from pathlib import Path

from datasets import load_dataset
from openai import OpenAI

# Configuration
API_BASE = "http://172.16.34.22:3090/v1"
MODEL_PATH = "sc"
DATASET_FILE = Path(
    "/home/mshahidul/readctrl/data/finetuning_data/finetune_dataset_subclaim_support_v2.json"
)
TEXT_VARIANT = "hard_text"
RESULTS_FILE = Path("inference_results.json")

# 1. Initialize OpenAI client pointed at the local vLLM endpoint.
client = OpenAI(api_key="EMPTY", base_url=API_BASE)

# Llama-3 chat template, rendered locally because we call the raw
# /completions endpoint (not /chat/completions), so the server does not
# apply a chat template for us.
CHAT_TEMPLATE = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    "Cutting Knowledge Date: December 2023\n"
    "Today Date: 26 July 2024\n\n"
    "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
    "{user_prompt}"
    "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)


def render_chat_prompt(user_prompt: str) -> str:
    """Wrap *user_prompt* in the Llama-3 chat template and return the result."""
    return CHAT_TEMPLATE.format(user_prompt=user_prompt)


def build_user_prompt(text: str, subclaims: list[str]) -> str:
    """Build the evidence-checker prompt: instructions, passage, numbered subclaims.

    Args:
        text: The medical passage to check the subclaims against.
        subclaims: Subclaim strings, labeled in order by the model.

    Returns:
        The full user-turn prompt string.
    """
    numbered_subclaims = "\n".join(
        f"{idx + 1}. {s}" for idx, s in enumerate(subclaims)
    )
    return (
        "You are a medical evidence checker.\n"
        "Given a medical passage and a list of subclaims, return labels for each "
        "subclaim in the same order.\n\n"
        "Allowed labels: supported, not_supported.\n"
        "Output format: a JSON array of strings only.\n\n"
        f"Medical text:\n{text}\n\n"
        f"Subclaims:\n{numbered_subclaims}"
    )


def main():
    """Run inference over the test split and write predictions to RESULTS_FILE."""
    # 2. Load the original dataset.
    raw_dataset = load_dataset("json", data_files=str(DATASET_FILE), split="train")

    # 3. Re-create the test split (same seed/ratio as used for training).
    splits = raw_dataset.train_test_split(test_size=0.1, seed=3407, shuffle=True)
    test_split = splits["test"]

    print(f"Running inference on {len(test_split)} samples...")

    results = []
    for row in test_split:
        for item in row.get("items", []):
            text = item.get(TEXT_VARIANT, "").strip()
            # Extract subclaim texts and gold labels in one pass over the
            # records instead of two separate comprehensions.
            entries = item.get("subclaims", [])
            subclaims = [s["subclaim"] for s in entries]
            gold_labels = [s["label"] for s in entries]

            # Skip items with no passage or no subclaims to label.
            if not text or not subclaims:
                continue

            # 4. Render the Llama chat template locally and request
            # inference from vLLM.
            prompt = render_chat_prompt(build_user_prompt(text, subclaims))
            response = client.completions.create(
                model=MODEL_PATH,
                prompt=prompt,
                temperature=0,  # keep it deterministic
                max_tokens=256,
            )
            pred_text = response.choices[0].text.strip()

            print("--- Sample ---")
            print(f"Pred: {pred_text}")
            print(f"Gold: {gold_labels}")

            results.append({"predicted": pred_text, "gold": gold_labels})

    # Save results; ensure_ascii=False keeps any non-ASCII characters in the
    # medical text readable instead of emitting \uXXXX escapes.
    with RESULTS_FILE.open("w", encoding="utf-8") as f:
        json.dump(results, f, indent=4, ensure_ascii=False)


if __name__ == "__main__":
    main()