# Source: readCtrl_lambda / code / fine_tune_sft_dpo / qwen3_infer_bn.py
# Author: shahidul034 — commit 93694bb ("Update readCtrl repo")
"""
Run inference with the finetuned Bangla Qwen3 model on test_bn.json
and save the generation results under results/bn.
"""
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
import argparse
import json
import os
import re
from datetime import datetime
from typing import Any, Dict, List
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def strip_think_blocks(text: str) -> str:
    """Strip any <think>...</think> reasoning spans from *text*.

    Falls back to returning the input unchanged when stripping would leave
    an empty string (i.e. the entire output was a reasoning block).
    """
    without_think = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
    without_think = without_think.strip()
    if without_think:
        return without_think
    return text
# Paths (keep in sync with qwen3-finetune_bn.py)
BASE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo"
# Merged finetuned checkpoint directory produced by the training script.
MODEL_SAVE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/model/bn"
# Directory holding one prompt template file per health-literacy level.
PROMPT_DIR = os.path.join(BASE_DIR, "prompt_bn")
TEST_JSON = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/dataset/bn/test_bn.json"
RESULTS_DIR = os.path.join(BASE_DIR, "results", "bn")
# Language name substituted into the {source_lang} prompt placeholder.
SOURCE_LANG = "Bangla"
# Maps each test item's "label" value to its prompt template filename.
LABEL_TO_PROMPT_FILE = {
    "low_health_literacy": "prompt_low",
    "intermediate_health_literacy": "prompt_intermediate",
    "proficient_health_literacy": "prompt_proficient",
}
def load_prompts() -> Dict[str, str]:
    """Read all literacy-level prompt templates from PROMPT_DIR.

    Returns:
        Mapping from literacy label to the raw template text.

    Raises:
        FileNotFoundError: if any expected template file is missing.
    """
    templates: Dict[str, str] = {}
    for label, fname in LABEL_TO_PROMPT_FILE.items():
        template_path = os.path.join(PROMPT_DIR, fname)
        # Fail fast with a clear message rather than producing partial output.
        if not os.path.isfile(template_path):
            raise FileNotFoundError(f"Prompt file not found: {template_path}")
        with open(template_path, "r", encoding="utf-8") as fh:
            templates[label] = fh.read()
    return templates
def build_user_message(
    prompt_template: str,
    full_text: str,
    gold_summary: str,
    source_lang: str = SOURCE_LANG,
) -> str:
    """Substitute {full_text}, {gold_summary}, and {source_lang} placeholders.

    Sequential replacement (not str.format), so article text containing
    braces cannot break the template expansion.
    """
    substitutions = {
        "{full_text}": full_text,
        "{gold_summary}": gold_summary,
        "{source_lang}": source_lang,
    }
    message = prompt_template
    for placeholder, value in substitutions.items():
        message = message.replace(placeholder, value)
    return message
def parse_args() -> argparse.Namespace:
    """Parse the CLI flags controlling sampling and the output filename."""
    parser = argparse.ArgumentParser(
        description="Run inference with finetuned Qwen3-4B Bangla model on test_bn.json."
    )
    # (flag, add_argument keyword arguments) — declarative table keeps the
    # options easy to scan and extend.
    cli_options = [
        (
            "--max-new-tokens",
            dict(
                type=int,
                default=512,
                help="Maximum number of new tokens to generate per sample.",
            ),
        ),
        (
            "--temperature",
            dict(type=float, default=0.7, help="Sampling temperature."),
        ),
        (
            "--top-p",
            dict(type=float, default=0.9, help="Top-p (nucleus) sampling value."),
        ),
        (
            "--output-json",
            dict(
                type=str,
                default="test_bn_qwen3-4B_sft_inference.json",
                help=(
                    "Output JSON filename (saved under results/bn). "
                    "If it already exists, it will be overwritten."
                ),
            ),
        ),
    ]
    for flag, kwargs in cli_options:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
def load_model_and_tokenizer(model_dir: str):
    """Load the merged finetuned checkpoint and its tokenizer.

    Args:
        model_dir: Directory containing the merged model saved by
            qwen3-finetune_bn.py.

    Returns:
        Tuple of (model, tokenizer), with the model switched to eval mode.

    Raises:
        FileNotFoundError: if *model_dir* does not exist.
    """
    if not os.path.isdir(model_dir):
        raise FileNotFoundError(
            f"Finetuned model directory not found: {model_dir}. "
            "Make sure qwen3-finetune_bn.py was run with model saving enabled."
        )

    print(f"Loading tokenizer from {model_dir}")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    # Some checkpoints ship without a pad token; fall back to EOS so
    # generate() has a valid padding id.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading model from {model_dir}")
    # On GPU: bf16 weights with automatic device placement; on CPU: defaults.
    load_kwargs = {}
    if torch.cuda.is_available():
        load_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto"}
    model = AutoModelForCausalLM.from_pretrained(model_dir, **load_kwargs)
    model.eval()
    return model, tokenizer
def run_inference(
    model,
    tokenizer,
    test_items: List[Dict[str, Any]],
    prompts: Dict[str, str],
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> List[Dict[str, Any]]:
    """Generate one adapted text per test item with the finetuned model.

    Items with an empty ``fulltext`` or a ``label`` missing from *prompts*
    are passed through untouched except for ``model_gen_text=""`` and
    ``model_gen_skipped=True``. Input dicts are never mutated; each result
    is a shallow copy.
    """
    outputs: List[Dict[str, Any]] = []
    device = next(model.parameters()).device
    total = len(test_items)

    for count, sample in enumerate(test_items, start=1):
        label = sample.get("label")
        source_text = sample.get("fulltext", "")
        gold_summary = sample.get("summary", "")

        record = dict(sample)

        if not source_text or label not in prompts:
            # Nothing to generate for this sample; record the skip and move on.
            record["model_gen_text"] = ""
            record["model_gen_skipped"] = True
            outputs.append(record)
            continue

        user_msg = build_user_message(prompts[label], source_text, gold_summary)
        chat_text = tokenizer.apply_chat_template(
            [{"role": "user", "content": user_msg}],
            add_generation_prompt=True,
            tokenize=False,
            enable_thinking=False,
        )
        encoded = tokenizer(chat_text, return_tensors="pt").to(device)
        prompt_len = encoded["input_ids"].shape[-1]

        with torch.no_grad():
            sequences = model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Keep only the continuation: drop the prompt tokens from the output.
        continuation = sequences[0, prompt_len:]
        decoded = tokenizer.decode(continuation, skip_special_tokens=True).strip()

        record["model_gen_text"] = strip_think_blocks(decoded)
        record["model_name"] = "qwen3-4B_sft_bn"
        record["model_max_new_tokens"] = max_new_tokens
        record["model_temperature"] = temperature
        record["model_top_p"] = top_p
        outputs.append(record)

        if count % 10 == 0:
            print(f"Processed {count} / {total} samples")

    return outputs
def _resolve_output_path(filename: str) -> str:
    """Return the results path for *filename* under RESULTS_DIR.

    Ensures a ``.json`` suffix; if the path already exists, appends a
    timestamp to the stem to avoid silently overwriting earlier runs.
    """
    if not filename.endswith(".json"):
        filename += ".json"
    candidate = os.path.join(RESULTS_DIR, filename)
    if os.path.exists(candidate):
        stem, ext = os.path.splitext(filename)
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        candidate = os.path.join(RESULTS_DIR, f"{stem}_{stamp}{ext}")
    return candidate


def main():
    """Entry point: load prompts, data, and model; generate; write JSON."""
    args = parse_args()
    os.makedirs(RESULTS_DIR, exist_ok=True)

    print("Loading prompts from", PROMPT_DIR)
    prompts = load_prompts()

    print("Loading test data from", TEST_JSON)
    with open(TEST_JSON, "r", encoding="utf-8") as f:
        test_items = json.load(f)
    print(f"Test samples: {len(test_items)}")

    model, tokenizer = load_model_and_tokenizer(MODEL_SAVE_DIR)

    print(
        f"Running inference with max_new_tokens={args.max_new_tokens}, "
        f"temperature={args.temperature}, top_p={args.top_p}"
    )
    results = run_inference(
        model=model,
        tokenizer=tokenizer,
        test_items=test_items,
        prompts=prompts,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
    )

    output_path = _resolve_output_path(args.output_json)
    print("Saving results to", output_path)
    with open(output_path, "w", encoding="utf-8") as out_f:
        json.dump(results, out_f, ensure_ascii=False, indent=2)
    print("Done.")


if __name__ == "__main__":
    main()