import json
from pathlib import Path
from openai import OpenAI
from datasets import load_dataset
from transformers import AutoTokenizer
from unsloth.chat_templates import get_chat_template
# Configuration
API_BASE = "http://172.16.34.22:8086/v1"  # vLLM OpenAI-compatible endpoint
MODEL_PATH = "sc"  # model name as served by the vLLM endpoint
TOKENIZER_NAME = "meta-llama/Llama-3.1-8B-Instruct"
DATASET_FILE = Path("/home/mshahidul/readctrl/data/finetuning_data/finetune_dataset_subclaim_support_v2.json")
TEXT_VARIANT = "hard_text"  # which text field of each dataset item to evaluate
# 1. Initialize OpenAI Client
client = OpenAI(api_key="EMPTY", base_url=API_BASE)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
tokenizer = get_chat_template(tokenizer, chat_template="llama-3.1")
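# get_chat_template configures the tokenizer with the Llama 3.1 chat format, so
# apply_chat_template below renders prompts in the same layout the model saw
# during fine-tuning.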
def render_chat_prompt(user_prompt: str) -> str:
messages = [{"role": "user", "content": user_prompt}]
template = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
    return template
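# For reference, the rendered prompt should look roughly like this (Llama 3.1
# chat format):
#   <|begin_of_text|><|start_header_id|>user<|end_header_id|>
#
#   {user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>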
def build_user_prompt(text: str, subclaims: list[str]) -> str:
numbered_subclaims = "\n".join(f"{idx + 1}. {s}" for idx, s in enumerate(subclaims))
return (
"You are a medical evidence checker.\n"
"Given a medical passage and a list of subclaims, return labels for each "
"subclaim in the same order.\n\n"
"Allowed labels: supported, not_supported.\n"
"Output format: a JSON array of strings only.\n\n"
f"Medical text:\n{text}\n\n"
f"Subclaims:\n{numbered_subclaims}"
)
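# Hypothetical helper (not part of the original pipeline): the model is asked
# for a JSON array of strings, but generations can include stray text around
# it. This sketch extracts the first JSON array it can find and falls back to
# an empty list when parsing fails.
def parse_pred_labels(pred_text: str) -> list[str]:
    try:
        parsed = json.loads(pred_text)
        if isinstance(parsed, list):
            return [str(label) for label in parsed]
    except json.JSONDecodeError:
        pass
    # Fall back to the first bracketed span, e.g. when the model adds prose.
    start, end = pred_text.find("["), pred_text.rfind("]")
    if start != -1 and end > start:
        try:
            parsed = json.loads(pred_text[start : end + 1])
            if isinstance(parsed, list):
                return [str(label) for label in parsed]
        except json.JSONDecodeError:
            pass
    return []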
def main():
# 2. Load the original dataset
raw_dataset = load_dataset("json", data_files=str(DATASET_FILE), split="train")
    # 3. Re-create the test split (same seed and split ratio as fine-tuning)
splits = raw_dataset.train_test_split(test_size=0.1, seed=3407, shuffle=True)
test_split = splits["test"]
print(f"Running inference on {len(test_split)} samples...")
results = []
for row in test_split:
for item in row.get("items", []):
text = item.get(TEXT_VARIANT, "").strip()
            entries = item.get("subclaims", [])
            subclaims = [s["subclaim"] for s in entries]
            gold_labels = [s["label"] for s in entries]
# print("--------------------------------")
# print(text)
# print(subclaims)
# print(gold_labels)
# print("--------------------------------")
if not text or not subclaims:
continue
# 4. Render Llama chat template locally and request inference from vLLM.
prompt = render_chat_prompt(build_user_prompt(text, subclaims))
response = client.completions.create(
model=MODEL_PATH,
prompt=prompt,
temperature=0, # Keep it deterministic
max_tokens=256
)
pred_text = response.choices[0].text.strip()
print(f"--- Sample ---")
print(f"Pred: {pred_text}")
print(f"Gold: {gold_labels}")
results.append({
"predicted": pred_text,
"gold": gold_labels
})
# Save results
with open("inference_results.json", "w") as f:
json.dump(results, f, indent=4)
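# Hypothetical follow-up (not wired into main): load the saved predictions,
# parse each one with parse_pred_labels above, and report subclaim-level
# accuracy. Predictions that cannot be parsed, or that are shorter than the
# gold label list, are counted as wrong rather than skipped.
def summarize_accuracy(results_path: str = "inference_results.json") -> None:
    with open(results_path) as f:
        results = json.load(f)
    correct = total = 0
    for entry in results:
        gold = entry["gold"]
        pred = parse_pred_labels(entry["predicted"])
        total += len(gold)
        correct += sum(p == g for p, g in zip(pred, gold))
    if total:
        print(f"Subclaim accuracy: {correct / total:.3f} ({correct}/{total})")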
if __name__ == "__main__":
main()