| import os |
| import json |
| from datetime import datetime |
|
|
| import numpy as np |
| from datasets import Dataset |
| from openai import OpenAI |
| from transformers import AutoTokenizer |
| from unsloth.chat_templates import get_chat_template |
|
|
| |
| |
| |
| |
# --- Inference endpoint configuration (all overridable via environment) ---

# Base URL of the vLLM OpenAI-compatible server.
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "http://localhost:8040/v1")

# Model name as registered with the vLLM server.
VLLM_MODEL_NAME = os.getenv(
    "VLLM_MODEL_NAME",
    "classifier",
)

# vLLM servers usually ignore the key; "EMPTY" is the conventional placeholder.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "EMPTY")

# --- Dataset / evaluation configuration ---
data_path = "/home/mshahidul/readctrl/code/text_classifier/bn/testing_bn_full.json"  # JSON list of records
test_size = 0.2         # fraction of the data held out as the test split
seed = 42               # split shuffling seed; presumably matches the finetuning split — TODO confirm
prompt_language = "en"  # "en" -> English prompt; any other value -> Bengali prompt

# Output directories for per-example predictions and accuracy summaries.
model_info_dir = "/home/mshahidul/readctrl/code/text_classifier/bn/model_info"
ablation_dir = "/home/mshahidul/readctrl/code/text_classifier/bn/ablation_studies"
os.makedirs(model_info_dir, exist_ok=True)
os.makedirs(ablation_dir, exist_ok=True)

# Tokenizer is used only to render prompts with the gemma-3 chat template,
# so the served model receives the same formatting used at finetuning time.
BASE_MODEL_FOR_TEMPLATE = "unsloth/gemma-3-4b-it"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_FOR_TEMPLATE)
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")
|
|
| |
| |
| |
def build_classification_user_prompt(fulltext, gen_text, language=None):
    """Build the user-turn prompt asking the model to classify *gen_text*.

    Args:
        fulltext: Reference medical case description (shown for context only).
        gen_text: Generated text whose health-literacy level is to be classified.
        language: Prompt language; "en" yields the English prompt, any other
            value the Bengali one. Defaults to the module-level
            ``prompt_language`` so existing callers are unaffected.

    Returns:
        A single prompt string ending with the closed label set the model
        must answer from.
    """
    if language is None:
        # Backward-compatible fallback to the module-wide setting.
        language = prompt_language

    if language == "en":
        return (
            "You will be given a medical case description as reference (full text) and a generated text to classify. "
            "Determine the patient's health literacy level based only on the generated text.\n\n"
            f"Reference (full text):\n{fulltext}\n\n"
            f"Generated text (to classify):\n{gen_text}\n\n"
            "Reply with exactly one label from this set:\n"
            "low_health_literacy, intermediate_health_literacy, proficient_health_literacy"
        )

    # Bengali prompt: same structure; the label set stays in English so the
    # classifier's output vocabulary is language-independent.
    return (
        "আপনাকে রেফারেন্স হিসেবে মেডিকেল কেসের পূর্ণ বর্ণনা (reference full text) এবং মূলভাবে শ্রেণিবিন্যাস করার জন্য তৈরি করা টেক্সট (generated text) দেওয়া হবে। "
        "শুধুমাত্র তৈরি করা টেক্সট (generated text)-এর উপর ভিত্তি করে রোগীর স্বাস্থ্যজ্ঞান (health literacy) কোন স্তরের তা নির্ধারণ করুন।\n\n"
        f"Reference (full text):\n{fulltext}\n\n"
        f"Generated text (যেটি শ্রেণিবিন্যাস করতে হবে):\n{gen_text}\n\n"
        "শুধু নিচের সেট থেকে একটি লেবেল দিয়ে উত্তর দিন:\n"
        "low_health_literacy, intermediate_health_literacy, proficient_health_literacy"
    )
|
|
|
|
def build_classification_examples(raw_records):
    """Convert raw dataset records into classification-ready example dicts.

    Each returned dict carries the reference text, the text to classify,
    the gold label, and the fully rendered user prompt. Records whose
    "label" field is missing or blank are skipped (they cannot be scored).
    """
    prepared = []
    for rec in raw_records:
        gold = (rec.get("label") or "").strip()
        if not gold:
            # Unlabeled record: nothing to compare a prediction against.
            continue
        reference = rec.get("fulltext", "")
        generated = rec.get("gen_text", "")
        prepared.append(
            {
                "fulltext": reference,
                "gen_text": generated,
                "gold_label": gold,
                "user_prompt": build_classification_user_prompt(reference, generated),
            }
        )
    return prepared
|
|
|
|
| |
| |
| |
# OpenAI-compatible client pointed at the local vLLM server.
client = OpenAI(
    base_url=VLLM_BASE_URL,
    api_key=OPENAI_API_KEY,
)
|
|
|
|
def vllm_generate_label(user_prompt: str, max_tokens: int = 32) -> str:
    """Call vLLM endpoint using the same chat template as finetuning.

    Renders the single user turn with the gemma-3 chat template, requests a
    greedy (temperature 0) completion, and returns the first line of the
    model's output stripped of surrounding whitespace.
    """
    rendered = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )

    # Halt generation at any end-of-sequence / end-of-turn marker (covering
    # several template conventions) so the completion is just the label.
    stops = [
        tokenizer.eos_token,
        "<|endoftext|>",
        "\n",
        "<|im_end|>",
        "<eos>",
        "<end_of_turn>",
    ]

    completion = client.completions.create(
        model=VLLM_MODEL_NAME,
        prompt=rendered,
        temperature=0.0,
        max_tokens=max_tokens,
        stop=stops,
    )

    raw = completion.choices[0].text or ""

    # Defensive: keep only the first line even though "\n" is a stop sequence.
    first_line, _, _ = raw.strip().partition("\n")
    return first_line.strip()
|
|
|
|
| |
| |
| |
def load_test_split():
    """Load the dataset JSON and return the held-out test split.

    Splits with the module-level ``test_size``/``seed`` so the held-out set
    is reproducible across runs.
    """
    with open(data_path, "r", encoding="utf-8") as handle:
        records = json.load(handle)

    splits = Dataset.from_list(records).train_test_split(
        test_size=test_size, seed=seed, shuffle=True
    )
    return splits["test"]
|
|
|
|
| |
| |
| |
def evaluate_with_vllm(test_split):
    """Run the served classifier over *test_split* and score exact-match accuracy.

    Args:
        test_split: Iterable of raw records (fulltext / gen_text / label).

    Returns:
        Tuple ``(results, accuracy)`` where ``results`` is a list of
        per-example dicts (inputs, gold label, prediction, correctness) and
        ``accuracy`` is the exact-match fraction (0.0 when there are no
        labeled examples).
    """
    examples = build_classification_examples(test_split)
    results = []
    hits = 0

    for idx, example in enumerate(examples):
        try:
            prediction = vllm_generate_label(example["user_prompt"])
        except Exception as exc:
            # Best-effort: record the failure as the prediction (it will be
            # scored as incorrect) and keep evaluating the remaining examples.
            prediction = f"ERROR: {exc}"

        matched = prediction == example["gold_label"]
        hits += matched
        results.append(
            {
                "sample_index": idx,
                "fulltext": example["fulltext"],
                "gen_text": example["gen_text"],
                "gold_label": example["gold_label"],
                "predicted_label": prediction,
                "correct": matched,
            }
        )

    accuracy = hits / len(results) if results else 0.0
    return results, accuracy
|
|
|
|
def main():
    """Evaluate the served classifier on the test split and persist outputs.

    Writes per-example predictions to ``model_info_dir`` and a metrics
    summary to ``ablation_dir``, both timestamped, then prints the paths
    and the final accuracy.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Dots in the model name would read like extra file extensions; flatten them.
    model_tag = os.path.basename(str(VLLM_MODEL_NAME)).replace(".", "_")

    results, accuracy = evaluate_with_vllm(load_test_split())

    metrics = {
        "mode": "fulltext_gen_text_classification",
        "model_name": VLLM_MODEL_NAME,
        "dataset_path": data_path,
        "prompt_language": prompt_language,
        "seed": seed,
        "test_size": test_size,
        "examples_evaluated": len(results),
        "accuracy": accuracy,
        "timestamp": stamp,
        "inference_backend": "vllm_openai_server",
    }

    predictions_path = os.path.join(
        model_info_dir, f"{model_tag}_vllm_test_inference_{stamp}.json"
    )
    accuracy_path = os.path.join(
        ablation_dir, f"{model_tag}_vllm_classification_{stamp}.json"
    )

    with open(predictions_path, "w", encoding="utf-8") as out:
        json.dump(results, out, ensure_ascii=False, indent=2)

    with open(accuracy_path, "w", encoding="utf-8") as out:
        json.dump(metrics, out, ensure_ascii=False, indent=2)

    print(f"Saved vLLM test inference to: {predictions_path}")
    print(f"Saved vLLM test accuracy to: {accuracy_path}")
    print(f"Accuracy: {accuracy:.4f}")
|
|
|
|
# Script entry point: run the full evaluation when executed directly.
if __name__ == "__main__":
    main()