# Author: mshahidul
# Commit message: Initial commit of readCtrl code without large models
# Commit: 030876e
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
import json
import os
from datetime import datetime
import torch
from datasets import Dataset
from unsloth import FastModel
from unsloth.chat_templates import (
get_chat_template,
standardize_data_formats,
train_on_responses_only,
)
from trl import SFTConfig, SFTTrainer
# --- Run configuration ---
# Checkpoint loaded below; it is also the model fine-tuned in "finetune_and_eval" mode.
model_name = "unsloth/gemma-3-4b-it"
# JSON list of records; the code reads "id", "fulltext", "fulltext_subclaims",
# "summary" and "summary_subclaims" from each record.
data_path = "/home/mshahidul/readctrl/data/extracting_subclaim/bn/multiclinsum_test_en2bn_gemma(0_1000)_3396_extracted_subclaims_bn_0_end.json"
test_size = 0.2 # 1 - train_ratio (0.8)
# Shared seed for the train/test split, LoRA initialization and the trainer.
seed = 42
run_mode = "finetune_and_eval" # "finetune_and_eval" or "eval_base_only"
save_fp16_merged = True # whether to save merged fp16 model after finetuning
# Max subclaims to request in prompts. These only change the note appended to
# the prompt; gold subclaim lists are used as-is and outputs are not truncated.
MAX_SUBCLAIMS_FULLTEXT = 80
MAX_SUBCLAIMS_SUMMARY = 40
def get_model_size_from_name(name):
    """
    Return the parameter-size tag embedded in a checkpoint name.

    The last path component of ``name`` is split on "-". The first part that
    ends in "b" or "m" (case-insensitive) and also contains at least one digit
    is returned with its original casing, e.g. "unsloth/gemma-3-4b-it" -> "4b".
    Requiring a digit stops plain words such as "bloom" or "medium" from being
    mistaken for a size tag.

    Args:
        name: Model identifier, optionally prefixed with "org/".

    Returns:
        The size part (e.g. "4b", "1B", "560m"), or "unknown" if none is found.
    """
    base = name.split("/")[-1]
    for part in base.split("-"):
        token = part.lower()
        if token.endswith(("b", "m")) and any(ch.isdigit() for ch in token):
            return part
    return "unknown"
# Size tag parsed from the checkpoint name (e.g. "4b" for the default model);
# recorded in the metrics JSON and used in the metrics filename.
model_size = get_model_size_from_name(model_name)
def formatting_prompts_func(examples):
    """
    Render a batch of chat conversations into plain training strings.

    Meant for ``Dataset.map(..., batched=True)``: ``examples["conversations"]``
    is a list of role/content message lists (as produced by
    ``build_subclaim_examples``). Each conversation is rendered with the
    module-level ``tokenizer``'s chat template, without a generation prompt.

    Returns:
        {"text": [...]} with one rendered string per conversation.
    """
    rendered = []
    for conversation in examples["conversations"]:
        as_text = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )
        # Drop the template's leading "<bos>"; presumably the tokenizer adds
        # BOS itself when the trainer tokenizes this text (TODO confirm).
        rendered.append(as_text.removeprefix("<bos>"))
    return {"text": rendered}
def build_subclaim_user_prompt(medical_text, is_summary=False, max_subclaims=None):
    """
    Build the Bangla subclaim-extraction instruction for ``medical_text``.

    The instruction wording is fixed; it is intended to match the
    ``extraction_prompt`` in ``extract_bn_subclaims_vllm.py`` (not verifiable
    from this file). ``medical_text`` is inserted verbatim under the
    "Medical Text (Bangla):" header, and the prompt is stripped of surrounding
    whitespace.

    Args:
        medical_text: Passage (full text or summary) to extract subclaims from.
        is_summary: Accepted so call sites can say which side they are
            prompting for; it does not currently change the prompt.
        max_subclaims: If not None, a note asking for at most this many
            subclaims is appended after a blank line.

    Returns:
        The prompt string.
    """
    # The cap lives in a separate suffix so the core wording stays identical
    # whether or not a limit is requested.
    suffix = ""
    if max_subclaims is not None:
        suffix = f"\n\nNote: Extract at most {max_subclaims} subclaims, prioritizing the most important factual statements."
    core = f"""
You are an expert medical annotator. The following text is in Bangla (Bengali).
Your task is to extract granular, factual subclaims from the provided medical text.
A subclaim is the smallest standalone factual unit that can be independently verified.
Instructions:
1. Read the Bangla medical text carefully.
2. Extract factual statements explicitly stated in the text.
3. Each subclaim must:
- Be in Bangla (same language as the input)
- Contain exactly ONE factual assertion
- Come directly from the text (no inference or interpretation)
- Preserve original wording as much as possible
- Include any negation, uncertainty, or qualifier
4. Do NOT:
- Combine multiple facts into one subclaim
- Add new information
- Translate to another language
5. Return ONLY a valid JSON array of strings.
6. Use double quotes and valid JSON formatting only (no markdown, no commentary).
Medical Text (Bangla):
{medical_text}
Return format:
[
"subclaim 1",
"subclaim 2"
]
""".strip()
    return core + suffix
def build_subclaim_examples(raw_records):
    """
    Turn raw records into chat-style SFT examples for Bangla subclaim extraction.

    Each record contributes up to two examples, in this order:
      - fulltext -> fulltext_subclaims  (prompt cap MAX_SUBCLAIMS_FULLTEXT)
      - summary  -> summary_subclaims   (prompt cap MAX_SUBCLAIMS_SUMMARY)
    A side is skipped when its stripped text or its gold subclaim list is empty.
    The assistant turn is the gold list serialized as JSON with non-ASCII kept.

    Args:
        raw_records: Iterable of record dicts.

    Returns:
        List of {"conversations": [user_message, assistant_message]} dicts.
    """
    # Per side: (text field, gold field, is_summary flag, prompt subclaim cap).
    sides = (
        ("fulltext", "fulltext_subclaims", False, MAX_SUBCLAIMS_FULLTEXT),
        ("summary", "summary_subclaims", True, MAX_SUBCLAIMS_SUMMARY),
    )
    examples = []
    for record in raw_records:
        for text_key, gold_key, summary_flag, cap in sides:
            source_text = (record.get(text_key) or "").strip()
            gold = record.get(gold_key) or []
            if not source_text or not gold:
                continue
            prompt = build_subclaim_user_prompt(
                source_text,
                is_summary=summary_flag,
                max_subclaims=cap,
            )
            answer = json.dumps(gold, ensure_ascii=False)
            examples.append(
                {
                    "conversations": [
                        {"role": "user", "content": prompt},
                        {"role": "assistant", "content": answer},
                    ],
                }
            )
    return examples
def extract_conversation_pair(conversations):
    """
    Return the first user message and the first assistant reply of a chat.

    Accepts OpenAI-style messages ({"role": ..., "content": ...}) as well as
    ShareGPT-style messages ({"from": ..., "value": ...}). Role aliases are
    normalized: "human" -> "user", and "gpt"/"model" -> "assistant". Messages
    whose content is empty do not count, so a later non-empty message of the
    same role is used instead.

    Args:
        conversations: Iterable of message dicts; None is treated as empty.

    Returns:
        (user_prompt, gold_response). Either element is "" when no matching
        message is found.
    """
    role_aliases = {"human": "user", "gpt": "assistant", "model": "assistant"}
    user_prompt = ""
    gold_response = ""
    for message in conversations or ():
        raw_role = message.get("role") or message.get("from")
        role = role_aliases.get(raw_role, raw_role)
        content = message.get("content")
        if content is None:
            # ShareGPT stores the text under "value".
            content = message.get("value")
        content = content or ""
        if role == "user" and not user_prompt:
            user_prompt = content
        elif role == "assistant" and not gold_response:
            gold_response = content
        if user_prompt and gold_response:
            break
    return user_prompt, gold_response
def generate_prediction(user_prompt, max_new_tokens=1024):
    """
    Greedily generate the model's reply to a single user message.

    Uses the module-level ``tokenizer`` and ``model``. The prompt is rendered
    with the chat template plus a generation prompt, and only the newly
    generated tokens are decoded (special tokens skipped).

    Args:
        user_prompt: Content of the single user turn.
        max_new_tokens: Generation budget. Long Bangla subclaim lists can need
            many tokens, and a truncated JSON array will not parse downstream.

    Returns:
        The decoded, stripped completion text.
    """
    # Strip the template's "<bos>" exactly as formatting_prompts_func does for
    # training; the tokenizer call below adds BOS itself, so keeping it here
    # would feed a double BOS that training never saw.
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    ).removeprefix("<bos>")
    inputs = tokenizer(text=prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        # Greedy decoding; temperature is ignored when do_sample=False, so it
        # is not passed (passing it only triggers a generation-config warning).
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            use_cache=True,
        )
    prompt_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
# 1. Load Model and Tokenizer
# Base checkpoint is loaded quantized to 4-bit (load_in_4bit=True).
# NOTE(review): max_seq_length=4092 looks like a typo for 4096, and the SFT
# trainer below uses max_seq_length=2048 — confirm the intended context length.
model, tokenizer = FastModel.from_pretrained(
    model_name=model_name,
    max_seq_length=4092,
    load_in_4bit=True,
)
# 2. Data Preparation
# Apply the Gemma-3 chat template; the response-masking markers used in step 3
# ("<start_of_turn>user\n" / "<start_of_turn>model\n") assume this template.
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")
with open(data_path, "r", encoding="utf-8") as f:
    raw_data = json.load(f)
# Dataset.from_list expects a list of record dicts.
raw_dataset = Dataset.from_list(raw_data)
# Split at the record level, before expanding records into fulltext/summary
# examples, so both examples from one record land in the same split.
split_dataset = raw_dataset.train_test_split(
    test_size=test_size, seed=seed, shuffle=True
)
train_raw = split_dataset["train"]
test_raw = split_dataset["test"]
# Each training record yields up to two chat examples (fulltext and summary).
train_examples = build_subclaim_examples(train_raw)
train_dataset = Dataset.from_list(train_examples)
# Render conversations into the "text" column consumed by the SFT trainer.
# This runs in every run_mode, although only "finetune_and_eval" uses it.
train_dataset = train_dataset.map(formatting_prompts_func, batched=True)
# 3. Optional Finetuning
if run_mode == "finetune_and_eval":
    # Attach LoRA adapters (rank 8, alpha 16, no dropout) to the attention and
    # MLP projection layers of the loaded model.
    model = FastModel.get_peft_model(
        model,
        r=8,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        random_state=seed,
    )
    # Training setup: 2 examples per device x 4 accumulation steps per
    # optimizer step, 60 optimizer steps, linear LR schedule after 5 warmup
    # steps. bf16 is used when the GPU supports it, otherwise fp16.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        dataset_text_field="text",
        max_seq_length=2048,
        args=SFTConfig(
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            warmup_steps=5,
            max_steps=60,
            learning_rate=2e-4,
            fp16=not torch.cuda.is_bf16_supported(),
            bf16=torch.cuda.is_bf16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=seed,
            output_dir="outputs",
            report_to="none",
        ),
    )
    # Mask the loss so that only the model turns (after "<start_of_turn>model\n")
    # are trained on; the user turns carrying the prompt are ignored.
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<start_of_turn>user\n",
        response_part="<start_of_turn>model\n",
    )
    # Execute training
    save_dir = f"/home/mshahidul/readctrl_model/subclaim_support_extraction_bn/{model_name.split('/')[-1]}"
    os.makedirs(save_dir, exist_ok=True)
    trainer.train()
    if save_fp16_merged:
        # Merge the LoRA weights into the base model and save it in 16-bit.
        model.save_pretrained_merged(save_dir, tokenizer, save_method="merged_16bit")
    else:
        # Without this branch the trained weights would be discarded, while the
        # metrics still record save_dir as the model location. Save the LoRA
        # adapters so the run can be reloaded.
        model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)
elif run_mode == "eval_base_only":
    # No finetuning; evaluate the base model. save_dir is only a label that is
    # written into the metrics file.
    save_dir = f"BASE_MODEL:{model_name}"
else:
    raise ValueError(f"Unsupported run_mode: {run_mode}")
# 4. Test-set Inference + Accuracy
# Switch the (possibly LoRA-wrapped) model to Unsloth's inference mode and to
# eval mode before generating predictions.
FastModel.for_inference(model)
model.eval()
# Destination for per-example predictions (written at the end of the script).
model_info_dir = (
    "/home/mshahidul/readctrl/code/subclaim_support_extraction/inference_data"
)
# Destination for the aggregate metrics JSON.
ablation_dir = (
    "/home/mshahidul/readctrl/code/subclaim_support_extraction/ablation_studies"
)
os.makedirs(model_info_dir, exist_ok=True)
os.makedirs(ablation_dir, exist_ok=True)
# Shared run timestamp; used in both output filenames and stored in the metrics.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Filename prefix: last component of the model id with "." replaced by "_".
model_tag = model_name.split("/")[-1].replace(".", "_")
def _parse_subclaim_list(text):
"""Best-effort parse of a JSON list of subclaims from model output."""
if not text:
return []
text = text.strip()
# Strip any trailing reasoning markup if present
if "</think>" in text:
text = text.split("</think>")[-1].strip()
start_idx = text.find("[")
end_idx = text.rfind("]") + 1
if start_idx != -1 and end_idx > start_idx:
text_slice = text[start_idx:end_idx]
else:
text_slice = text
try:
parsed = json.loads(text_slice)
if isinstance(parsed, list):
return [str(s).strip() for s in parsed if s]
except Exception:
return []
return []
def _subclaim_metrics(gold, pred):
"""Compute simple set-based precision/recall/Jaccard for subclaim lists."""
gold_set = {s.strip() for s in gold if s}
pred_set = {s.strip() for s in pred if s}
if not gold_set and not pred_set:
return 1.0, 1.0, 1.0
if not pred_set:
return 0.0, 0.0, 0.0
inter = gold_set & pred_set
union = gold_set | pred_set
precision = len(inter) / len(pred_set) if pred_set else 0.0
recall = len(inter) / len(gold_set) if gold_set else 0.0
jaccard = len(inter) / len(union) if union else 0.0
return precision, recall, jaccard
def evaluate_subclaim_mode(test_split):
    """
    Run subclaim extraction on the held-out split and score it against gold.

    For every sample, each side (fulltext first, then summary) that has both
    non-empty text and a non-empty gold subclaim list is prompted with the same
    instruction builder used for training. The generated output is parsed into
    a list and scored with set-based precision/recall/Jaccard.

    Args:
        test_split: Iterable of record dicts with "id", "fulltext",
            "fulltext_subclaims", "summary" and "summary_subclaims" fields.

    Returns:
        (results, metrics):
          - results: one dict per scored side.
          - metrics: run metadata plus the scores macro-averaged over scored
            sides; "subclaim_score" equals the average Jaccard. Averages are
            0.0 when nothing was scored.
    """
    # Per side: (text field, gold field, is_summary flag, prompt subclaim cap).
    # The text field name doubles as the "source_type" label in results.
    sides = (
        ("fulltext", "fulltext_subclaims", False, MAX_SUBCLAIMS_FULLTEXT),
        ("summary", "summary_subclaims", True, MAX_SUBCLAIMS_SUMMARY),
    )
    results = []
    scored = 0
    totals = {"precision": 0.0, "recall": 0.0, "jaccard": 0.0}
    for sample_index, sample in enumerate(test_split):
        sample_id = sample.get("id")
        for text_key, gold_key, summary_flag, cap in sides:
            source_text = (sample.get(text_key) or "").strip()
            gold = sample.get(gold_key) or []
            if not source_text or not gold:
                continue
            prompt = build_subclaim_user_prompt(
                source_text,
                is_summary=summary_flag,
                max_subclaims=cap,
            )
            predicted = _parse_subclaim_list(generate_prediction(prompt))
            precision, recall, jaccard = _subclaim_metrics(gold, predicted)
            scored += 1
            totals["precision"] += precision
            totals["recall"] += recall
            totals["jaccard"] += jaccard
            results.append(
                {
                    "sample_index": sample_index,
                    "id": sample_id,
                    "source_type": text_key,
                    "input_text": source_text,
                    "gold_subclaims": gold,
                    "predicted_subclaims": predicted,
                    "precision": precision,
                    "recall": recall,
                    "jaccard": jaccard,
                }
            )

    def _mean(key):
        # Average over scored sides; guard against an empty evaluation.
        return totals[key] / scored if scored else 0.0

    avg_precision = _mean("precision")
    avg_recall = _mean("recall")
    avg_jaccard = _mean("jaccard")
    metrics = {
        "mode": "bangla_subclaim_extraction",
        "model_name": model_name,
        "model_save_dir": save_dir,
        "dataset_path": data_path,
        "seed": seed,
        "test_size": test_size,
        "examples_evaluated": scored,
        "avg_precision": avg_precision,
        "avg_recall": avg_recall,
        "avg_jaccard": avg_jaccard,
        "subclaim_score": avg_jaccard,
        "timestamp": timestamp,
    }
    return results, metrics
# Evaluate, tag the metrics with run metadata, and write both JSON outputs.
results, accuracy_summary = evaluate_subclaim_mode(test_raw)
accuracy_summary.update(
    {
        "finetune_mode": "subclaim_extraction",
        "model_size": model_size,
        "run_mode": run_mode,
        "language": "bn",
    }
)
predictions_path = os.path.join(
    model_info_dir, f"{model_tag}_test_inference_{timestamp}.json"
)
accuracy_path = os.path.join(
    ablation_dir, f"{model_tag}_subclaim_{model_size}_{run_mode}_{timestamp}.json"
)
# Per-example predictions first, then the aggregate metrics.
for out_path, payload in ((predictions_path, results), (accuracy_path, accuracy_summary)):
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
print(f"Saved test inference to: {predictions_path}")
print(f"Saved test metrics to: {accuracy_path}")
print(
    f"Avg Jaccard (subclaim_score): {accuracy_summary.get('subclaim_score', 0.0):.4f}"
)