import os
import json
import tqdm
import argparse
from openai import OpenAI
# -----------------------------
# CONFIGURATION
# -----------------------------
# Model identifier sent as `model=` on every chat-completions request. The
# value is a filesystem path; presumably it is the name under which the
# serving endpoint registered the merged checkpoint -- TODO confirm.
MODEL_NAME = "/home/mshahidul/readctrl_model/full_model/qwen3-32B_subclaims_BF16_merged"
# Base URL of the endpoint the OpenAI client sends its requests to.
API_URL = "http://localhost:8015/v1"
# Placeholder key. NOTE(review): presumably the local server does not validate
# it -- confirm before pointing API_URL at anything else.
API_KEY = "EMPTY"
# Shared client, constructed when the module is imported and used by
# infer_subclaims() for every request.
client = OpenAI(base_url=API_URL, api_key=API_KEY)
# -----------------------------
# SUBCLAIM EXTRACTION PROMPT
# -----------------------------
def extraction_prompt(medical_text: str) -> str:
    """Build the subclaim-extraction prompt for one passage of medical text.

    The passage is embedded verbatim between a fixed instruction preamble
    (annotator role, definition of a subclaim, extraction rules) and a fixed
    description of the expected output, a JSON list of strings.

    Args:
        medical_text: Text to place under the "Medical Text:" heading.

    Returns:
        The complete prompt string.
    """
    # Everything that precedes the passage, including the leading newline.
    preamble = (
        "\n"
        "You are an expert medical annotator. Extract granular, factual subclaims.\n"
        "A subclaim is the smallest standalone factual unit that can be independently verified.\n"
        "Rules:\n"
        "- Use only information explicitly present in the text.\n"
        "- Do not infer or hallucinate.\n"
        "- Subclaims must be atomic and factual.\n"
        "- Return ONLY a JSON list of strings.\n"
        "Medical Text:\n"
    )
    # Everything that follows the passage: the output-format example.
    output_spec = (
        "\n"
        "Return output as:\n"
        "[\n"
        '"subclaim 1",\n'
        '"subclaim 2",\n'
        "...\n"
        "]\n"
    )
    return f"{preamble}{medical_text}{output_spec}"
# -----------------------------
# INFERENCE FUNCTION
# -----------------------------
def infer_subclaims(medical_text: str, temperature: float = 0.2) -> "list | str":
    """Extract subclaims from ``medical_text`` with one chat-completions call.

    The text is wrapped with ``extraction_prompt`` and sent to ``MODEL_NAME``
    through the module-level ``client``. Anything up to and including the last
    ``</think>`` marker in the reply is discarded, and the remainder is parsed
    as JSON.

    Args:
        medical_text: Text to extract subclaims from. ``None``, empty and
            whitespace-only values short-circuit without calling the API.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The parsed JSON value (a list of strings when the model follows the
        prompt). If the reply is not valid JSON, the raw reply text is returned
        instead so it can be inspected later. An empty list is returned for
        empty input, for an empty reply, and when the request fails.
    """
    if not medical_text or not medical_text.strip():
        return []

    final_prompt = extraction_prompt(medical_text)

    # Only the request itself is guarded by the broad handler: a failed call
    # is reported and degraded to "no subclaims" so a long batch keeps going.
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": final_prompt}],
            max_tokens=1000,
            temperature=temperature,
            top_p=0.9,
        )
        # `content` may be None; normalise so the string handling below is safe.
        raw_reply = response.choices[0].message.content or ""
    except Exception as exc:
        print(f"API error: {exc}")
        return []

    # Keep only what follows the model's reasoning section, if there is one.
    res = raw_reply.strip().split("</think>")[-1].strip()
    if not res:
        return []

    try:
        return json.loads(res)
    except json.JSONDecodeError:
        # Not valid JSON: hand back the raw text rather than dropping it.
        return res
# -----------------------------
# MAIN
# -----------------------------
def _load_records_by_id(path: str) -> dict:
    """Load a JSON list of records from ``path`` and index it by each record's "id"."""
    with open(path, "r", encoding="utf-8") as f:
        return {x["id"]: x for x in json.load(f)}


def _load_previous_results(path: str) -> list:
    """Return the results saved by an earlier run, or [] if none can be read."""
    if not os.path.exists(path):
        return []
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError) as exc:
        # ValueError covers both invalid JSON and undecodable bytes.
        print(f"Could not read previous results from {path} ({exc}); starting fresh.")
        return []


def _save_results(result: list, path: str) -> None:
    """Write ``result`` to ``path`` as JSON without ever leaving a partial file."""
    # Dump to a sibling temp file first, then swap it in: an interruption
    # mid-write can no longer corrupt the checkpoint that resume relies on.
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=4, ensure_ascii=False)
    os.replace(tmp_path, path)


def main() -> None:
    """Extract subclaims for a range of IDs and save them to a JSON file.

    Records from the two input files are joined on "id". For each ID in the
    requested slice of the sorted ID list, subclaims are extracted for the
    easy / intermediate / hard readability versions (file 1) and for the
    fulltext and summary (file 2). Texts missing for an ID are treated as
    empty strings. IDs already present in the output file are skipped, so an
    interrupted run can be resumed with the same arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--file1", type=str, required=True,
                        help="Path to synthetic_data_es_raw_592.json")
    parser.add_argument("--file2", type=str, required=True,
                        help="Path to multiclinsum_gs_train_es.json")
    parser.add_argument("--start_index", type=int, default=0,
                        help="Start index for processing")
    parser.add_argument("--end_index", type=int, default=-1,
                        help="End index for processing (exclusive). -1 = until end")
    parser.add_argument("--save_folder", type=str,
                        default="/home/mshahidul/readctrl/data/extracting_subclaim",
                        help="Folder the output JSON is written to")
    args = parser.parse_args()

    os.makedirs(args.save_folder, exist_ok=True)
    # Output filename includes the requested range.
    output_file = os.path.join(
        args.save_folder,
        f"extracted_subclaims_{args.start_index}_{args.end_index}.json",
    )

    # -----------------------------
    # Load files
    # -----------------------------
    print("Loading input files...")
    file1_data = _load_records_by_id(args.file1)
    file2_data = _load_records_by_id(args.file2)

    # -----------------------------
    # Merge and slice by range
    # -----------------------------
    all_ids = sorted(file1_data.keys() | file2_data.keys())
    total_items = len(all_ids)
    start = args.start_index
    end = args.end_index if args.end_index != -1 else total_items
    slice_ids = all_ids[start:end]
    print(f"Total IDs: {total_items}")
    print(f"Processing range: {start} → {end} (count={len(slice_ids)})")

    # -----------------------------
    # Resume mode
    # -----------------------------
    result = _load_previous_results(output_file)
    existing_ids = {r["id"] for r in result}

    # -----------------------------
    # Process items
    # -----------------------------
    save_every = 20
    try:
        for _id in tqdm.tqdm(slice_ids):
            if _id in existing_ids:
                continue

            # FILE1 text
            easy_text = inter_text = hard_text = ""
            if _id in file1_data:
                rv = file1_data[_id]["readability_versions"]
                easy_text = rv.get("easy", {}).get("text", "")
                inter_text = rv.get("intermediate", {}).get("text", "")
                hard_text = rv.get("hard", {}).get("text", "")

            # FILE2 text
            fulltext = summary = ""
            if _id in file2_data:
                fulltext = file2_data[_id].get("fulltext", "")
                summary = file2_data[_id].get("summary", "")

            result.append({
                "id": _id,
                "easy_text": easy_text,
                "easy_subclaims": infer_subclaims(easy_text),
                "intermediate_text": inter_text,
                "intermediate_subclaims": infer_subclaims(inter_text),
                "hard_text": hard_text,
                "hard_subclaims": infer_subclaims(hard_text),
                "fulltext": fulltext,
                "fulltext_subclaims": infer_subclaims(fulltext),
                "summary": summary,
                "summary_subclaims": infer_subclaims(summary),
            })

            # Periodic checkpoint.
            if len(result) % save_every == 0:
                _save_results(result, output_file)
    finally:
        # Runs on normal completion, on Ctrl-C and on unexpected errors, so
        # items finished since the last checkpoint are not lost.
        _save_results(result, output_file)

    print(f"Done! Saved to: {output_file}")


if __name__ == "__main__":
    main()