import json
from pathlib import Path
# Input file (synthetic subclaims dataset): each record holds a generated
# passage plus its subclaims with support labels.
DATA_PATH = Path(
    "/home/mshahidul/readctrl/data/extracting_subclaim/synthetic_subclaims_first200.json"
)
# Output file: SFT-formatted conversation records derived from DATA_PATH.
OUTPUT_PATH = Path(
    "/home/mshahidul/readctrl/data/finetuning_data/dataset_for_sft_support_check_list_new.json"
)
def training_prompt(medical_text, subclaims, labels):
    """Build one SFT conversation record for subclaim support adjudication.

    Args:
        medical_text: Source passage the subclaims are checked against.
        subclaims: List of subclaim strings to adjudicate.
        labels: Parallel list of labels ('supported' / 'not_supported'),
            emitted verbatim as the assistant's JSON-array answer.

    Returns:
        dict with a "conversations" key: a user turn (the full prompt) and
        an assistant turn (the labels serialized as a JSON array).
    """
    numbered_subclaims = "\n".join(
        f"{idx + 1}. {claim}" for idx, claim in enumerate(subclaims)
    )
    system_prompt = f"""
You are an expert medical adjudicator. Determine if the 'Medical Passage' contains the core factual information of each 'Subclaim', even if the passage uses simpler language or layperson terms.
Rules:
- Label 'supported' if the essential meaning is present.
- Label 'not_supported' only if the information is missing or contradicted.
Output: JSON array of strings ['supported', 'not_supported', ...]
Medical text:
{medical_text}
Subclaims:
{numbered_subclaims}
"""
    # Use a list, not a tuple: json round-trips "conversations" as a JSON
    # array, so reloading the written file yields a list — keep the
    # in-memory type consistent with the persisted one.
    conversation = {}
    conversation["conversations"] = [
        {"from": "user", "content": system_prompt},
        {"from": "assistant", "content": json.dumps(labels, ensure_ascii=False)},
    ]
    return conversation
def load_conversation_dataset(data_path=DATA_PATH):
    """Read the synthetic-subclaims JSON and build SFT conversation records.

    Subclaims with a blank 'claim_text' are dropped; a missing 'label'
    defaults to 'not_supported'. Records whose passage is empty or whose
    subclaim list ends up empty are skipped entirely.
    """
    raw_data = json.loads(Path(data_path).read_text(encoding="utf-8"))
    formatted_data = []
    for record in raw_data:
        generated = record.get("generated", {})
        passage = generated.get("passage", "")
        claims, claim_labels = [], []
        for item in generated.get("subclaims", []):
            text = item.get("claim_text", "").strip()
            if not text:
                continue
            claims.append(text)
            claim_labels.append(item.get("label", "not_supported"))
        if passage and claims:
            formatted_data.append(training_prompt(passage, claims, claim_labels))
    return formatted_data
# Example usage: build the SFT dataset and persist it for fine-tuning.
dataset_for_sft = load_conversation_dataset()
# Create the destination directory if needed so the open() below cannot fail
# with FileNotFoundError on a fresh checkout.
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
with OUTPUT_PATH.open("w", encoding="utf-8") as f:
    json.dump(dataset_for_sft, f, ensure_ascii=False, indent=2)
print(len(dataset_for_sft))
# Guard the sample print: indexing [0] on an empty dataset raises IndexError.
if dataset_for_sft:
    print(dataset_for_sft[0])