# File size: 5,930 Bytes
"""Readability-controlled rewriting of Spanish clinical case texts.

For each case in the test split, the script asks a (fine-tuned) Qwen3-14B
model for an "easy", an "intermediate" and a "hard" rewrite. When a glossary of
medical keyword definitions exists for the case, it is appended to the input.
All rewrites are stored in a JSON file; re-running the script reloads that
file and skips cases that are already present.
"""
import argparse
import os
import json
import sys

# Make the user-level helper module `gpu_selection` importable.
sys.path.append(os.path.abspath('/home/mshahidul/'))
from gpu_selection import _gpu_selection_

parser = argparse.ArgumentParser(description="Readability Controlled Generation")
parser.add_argument(
    "--cuda",
    type=str,
    default="3",
    help='Value for CUDA_VISIBLE_DEVICES, e.g. "3" or "0,1". '
    'Pass "" or "auto" to let gpu_selection pick the device instead.',
)
parser.add_argument(
    "--model_name",
    type=str,
    default="/home/mshahidul/readctrl/finetuned_models/es_synthetic_data_creation_Qwen3_14B_v2",
    help="Path to a local checkpoint (or a model identifier) to load with unsloth.",
)
parser.add_argument(
    "--temperature",
    type=float,
    default=0.1,
    help="Sampling temperature passed to model.generate().",
)
args = parser.parse_args()
model_name = args.model_name
temperature = args.temperature

# Device selection must happen before torch/unsloth initialise CUDA; both are
# imported further down in this script.
cuda_arg = (args.cuda or "").strip()
if cuda_arg and cuda_arg.lower() != "auto":
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_arg
    print(f"🎮🎮 Using CUDA device: {cuda_arg}")
else:
    # NOTE(review): presumably _gpu_selection_ picks a free GPU and sets
    # CUDA_VISIBLE_DEVICES itself — confirm against gpu_selection.py.
    _gpu_selection_()
# System prompts, one per target readability band. The generation loop iterates
# over exactly these three keys and sends prompts[band].strip() as the system
# message, so the leading/trailing newlines of each literal never reach the model.
prompts = {
# Band 1: simplest rewrite; the prompt itself states the Fernandez Huerta 70-100 target.
"easy": '''
You are an assistant that rewrites Spanish texts to make them very simple and easy to understand.
Your goal is to rewrite the provided input text for younger readers (Fernández Huerta 70–100; grade 5–7).
Use short sentences, simple words, and friendly tone. Avoid technical or complex expressions.
Keep all important factual details, but remove jargon.
Return only the rewritten text without commentary.
''',
# Band 2: medium readability; the prompt states the Fernandez Huerta 50-70 target.
"intermediate": '''
You are an assistant specialized in rewriting Spanish texts with medium readability.
Your task is to rewrite the provided input text for general or high‑school‑level readers (Fernández Huerta 50–70; grade 8–12).
Use clear and complete sentences, moderately complex vocabulary, and structured narration.
Retain all relevant medical or factual information, but phrase it in accessible language.
Return only the rewritten text with no explanations.
''',
# Band 3: technical/professional rewrite; the prompt states the Fernandez Huerta 0-50 target.
"hard": '''
You are an assistant that rewrites Spanish medical texts with professional, technical precision.
Rewrite the following input text using specialized, academic terminology and information‑dense phrasing.
The output must target a Fernández Huerta readability index between 0 and 50 (university/professional level).
Use clinical vocabulary, formal register, and detailed description of pathophysiology, procedures, and findings.
Return only the rewritten text.
'''
}
# ---- Keyword/definition glossary: case id -> text block appended to the input ----
kw_file = "/home/mshahidul/readctrl/data/kyw_def_train/kyw_gen_gpt5.json"
with open(kw_file, "r", encoding="utf-8") as fh:
    definitions_data = json.load(fh)


def _format_glossary(keywords):
    """Render keyword entries as a bulleted glossary block.

    Each entry must carry "term" and "definition" keys (a missing key raises
    KeyError). Returns "" when *keywords* is empty or None.
    """
    if not keywords:
        return ""
    bullets = "\n".join(f"• {entry['term']} — {entry['definition']}" for entry in keywords)
    return "Relevant medical definitions:\n" + bullets


# One entry per record, keyed by its "id" (None when absent); if an id repeats,
# the last record wins.
def_map = {
    record.get("id"): _format_glossary(record.get("medical_keywords", []))
    for record in definitions_data
}
# ---- Input split, output location and resume state ----
path = "/home/mshahidul/readctrl/data/testing_data/multiclinsum_test_es.json"
out_dir = "/home/mshahidul/readctrl/results/v3_context"
os.makedirs(out_dir, exist_ok=True)

# A model_name that exists on disk is labelled "finetuned" in the output file
# name; anything else (presumably a hub model id) is labelled "base".
out_path = out_dir + (
    f"/temp{temperature}_qwen3-14B_finetuned_with_defs.json"
    if os.path.exists(model_name)
    else f"/temp{temperature}_qwen3-14B_base_with_defs.json"
)

# Resume support: reload earlier output and remember which source texts were
# already processed (matched on the full input text).
results = []
if os.path.exists(out_path):
    with open(out_path, "r", encoding="utf-8") as f:
        results = json.load(f)
completed_keys = {record["fulltext"] for record in results}
# ---- Evaluation data: only the first 50 cases of the test split are used ----
with open(path, "r", encoding="utf-8") as fh:
    dataset = json.load(fh)[:50]

# Imported only here, i.e. after CUDA_VISIBLE_DEVICES was configured above.
from unsloth import FastLanguageModel
import torch  # NOTE(review): not referenced anywhere else in this script

# Unquantised weights (neither 4-bit nor 8-bit); full fine-tuning disabled.
load_options = dict(
    max_seq_length=4092,  # NOTE(review): possibly intended as 4096 — confirm
    load_in_4bit=False,
    load_in_8bit=False,
    full_finetuning=False,
)
model, tokenizer = FastLanguageModel.from_pretrained(model_name=model_name, **load_options)
import tqdm

# Readability bands generated for every case, in this order.
BANDS = ("easy", "intermediate", "hard")
# Output-length budget per band as a fraction of the prompt length (in tokens);
# the result is clamped to [150, 1200] new tokens below.
LENGTH_FACTORS = {"easy": 0.5, "intermediate": 0.8, "hard": 1.1}


def _save_results(records, destination):
    """Write *records* to *destination* as JSON, atomically.

    The data goes to a sibling temp file which then replaces the target, so a
    crash mid-write cannot leave a truncated file that breaks resuming.
    """
    tmp_path = destination + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as fh:
        json.dump(records, fh, ensure_ascii=False, indent=2)
    os.replace(tmp_path, destination)


for item in tqdm.tqdm(dataset):
    key = item["fulltext"]
    if key in completed_keys:
        continue  # already generated in an earlier run
    item_id = item["id"]
    glossary = def_map.get(item_id, "")  # "" when no definitions exist for this case

    # The user message is the same for every band: case text, then the optional glossary.
    user_content = f"Input text:\n{item['fulltext'].strip()}"
    if glossary:
        user_content += "\n\n" + glossary

    for band in BANDS:
        messages = [
            {"role": "system", "content": prompts[band].strip()},
            {"role": "user", "content": user_content},
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        inputs = tokenizer(text, return_tensors="pt").to("cuda")
        input_len = inputs.input_ids.shape[1]
        max_new_tokens = int(min(1200, max(150, input_len * LENGTH_FACTORS[band])))
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Be explicit: temperature/top_p/top_k only take effect when sampling;
            # a temperature of 0 falls back to greedy decoding.
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=0.9,
            top_k=45,
        )
        # generate() returns prompt + continuation for decoder-only models.
        # Decode only the new tokens so the prompt does not end up in the summary.
        new_tokens = output_ids[0][input_len:]
        output_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        results.append({
            "id": item_id,
            "fulltext": item["fulltext"],
            "band": band,
            "lang": "es",
            "synthetic_summary": output_text,
            "definitions_used": bool(glossary),  # track whether glossary applied
        })

    # Mark the case done only after all bands exist, then checkpoint, so a
    # crash loses at most the case in progress.
    completed_keys.add(key)
    _save_results(results, out_path)

_save_results(results, out_path)
# Completion notice. Results are already saved at this point, so a failing
# notifier is reported instead of turning a finished run into an error.
try:
    # `notifier` is not stdlib; presumably a local helper module.
    from notifier import send_notification

    send_notification(
        "process-complete1507034",
        f"Finished inference with model {model_name} at temperature {temperature}. Results saved to {out_path}",
        title="Inference Complete",
        priority="default",
        tags="tada",
    )
except Exception as exc:  # top-level, best-effort boundary
    print(f"Could not send completion notification: {exc!r}")