"""Evaluate a fine-tuned health-literacy classifier served by vLLM.

Loads the same dataset and test split as the finetune script, builds
classification prompts with the Gemma-3 chat template, queries the vLLM
OpenAI-compatible completions endpoint, and writes per-sample predictions
plus an accuracy summary to disk.
"""

import os
import json
from datetime import datetime

import numpy as np
from datasets import Dataset
from openai import OpenAI
from transformers import AutoTokenizer
from unsloth.chat_templates import get_chat_template

# -----------------------------
# Configuration
# -----------------------------
# vLLM server (OpenAI-compatible) URL, e.g. "http://localhost:8000/v1"
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "http://localhost:8040/v1")

# Model name as seen by the vLLM server (can be an HF repo id or local path)
VLLM_MODEL_NAME = os.getenv(
    "VLLM_MODEL_NAME",
    "classifier",  # adjust if needed
)

# Dummy key is fine for vLLM if auth is disabled
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "EMPTY")

# Data and output paths (mirrors finetune script)
data_path = "/home/mshahidul/readctrl/code/text_classifier/bn/testing_bn_full.json"
test_size = 0.2
seed = 42
prompt_language = "en"  # "bn" or "en"

model_info_dir = "/home/mshahidul/readctrl/code/text_classifier/bn/model_info"
ablation_dir = "/home/mshahidul/readctrl/code/text_classifier/bn/ablation_studies"
os.makedirs(model_info_dir, exist_ok=True)
os.makedirs(ablation_dir, exist_ok=True)

# -----------------------------
# Chat template / tokenizer (match finetune script)
# -----------------------------
BASE_MODEL_FOR_TEMPLATE = "unsloth/gemma-3-4b-it"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_FOR_TEMPLATE)
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")


# -----------------------------
# Prompt construction (copied from finetune script)
# -----------------------------
def build_classification_user_prompt(fulltext, gen_text):
    """Build the user prompt for one sample.

    Input: ``fulltext`` (reference) + ``gen_text`` (main text to classify).
    The model is asked to reply with exactly one of the three health-literacy
    labels. Language is selected by the module-level ``prompt_language``.
    """
    if prompt_language == "en":
        return (
            "You will be given a medical case description as reference (full text) and a generated text to classify. "
            "Determine the patient's health literacy level based only on the generated text.\n\n"
            f"Reference (full text):\n{fulltext}\n\n"
            f"Generated text (to classify):\n{gen_text}\n\n"
            "Reply with exactly one label from this set:\n"
            "low_health_literacy, intermediate_health_literacy, proficient_health_literacy"
        )
    # Bangla (default)
    return (
        "আপনাকে রেফারেন্স হিসেবে মেডিকেল কেসের পূর্ণ বর্ণনা (reference full text) এবং মূলভাবে শ্রেণিবিন্যাস করার জন্য তৈরি করা টেক্সট (generated text) দেওয়া হবে। "
        "শুধুমাত্র তৈরি করা টেক্সট (generated text)-এর উপর ভিত্তি করে রোগীর স্বাস্থ্যজ্ঞান (health literacy) কোন স্তরের তা নির্ধারণ করুন।\n\n"
        f"Reference (full text):\n{fulltext}\n\n"
        f"Generated text (যেটি শ্রেণিবিন্যাস করতে হবে):\n{gen_text}\n\n"
        "শুধু নিচের সেট থেকে একটি লেবেল দিয়ে উত্তর দিন:\n"
        "low_health_literacy, intermediate_health_literacy, proficient_health_literacy"
    )


def build_classification_examples(raw_records):
    """Convert raw dataset records into evaluation examples.

    Records with a missing/blank ``label`` are skipped. Each example carries
    the original texts, the gold label, and the fully-built user prompt.
    """
    examples = []
    for record in raw_records:
        fulltext = record.get("fulltext", "")
        gen_text = record.get("gen_text", "")
        label = (record.get("label") or "").strip()
        if not label:
            continue  # unlabeled sample — cannot be scored
        user_prompt = build_classification_user_prompt(fulltext, gen_text)
        examples.append(
            {
                "fulltext": fulltext,
                "gen_text": gen_text,
                "gold_label": label,
                "user_prompt": user_prompt,
            }
        )
    return examples


# -----------------------------
# vLLM client
# -----------------------------
client = OpenAI(
    base_url=VLLM_BASE_URL,
    api_key=OPENAI_API_KEY,
)


def vllm_generate_label(user_prompt: str, max_tokens: int = 32) -> str:
    """Call the vLLM endpoint using the same chat template as finetuning.

    Args:
        user_prompt: Fully-built classification prompt for one sample.
        max_tokens: Generation budget; labels are single tokens-worth of text,
            so 32 is ample.

    Returns:
        The first line of the stripped completion — expected to be one of the
        three health-literacy labels.
    """
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )

    # Stop sequences keep the model from rambling past the label. The
    # OpenAI-compatible completions API accepts at most 4 stop strings and
    # rejects empty ones, and tokenizer.eos_token can be None for some
    # tokenizers — so drop falsy entries, de-duplicate preserving order,
    # and cap at 4. Since labels are single words, stopping at "\n" is safest.
    candidates = [tokenizer.eos_token, "<|endoftext|>", "\n", "<|im_end|>"]
    stop_sequences = list(dict.fromkeys(s for s in candidates if s))[:4]

    response = client.completions.create(
        model=VLLM_MODEL_NAME,
        prompt=prompt,
        temperature=0.0,  # deterministic decoding: this is classification
        max_tokens=max_tokens,
        stop=stop_sequences,
    )
    content = response.choices[0].text or ""

    # Take the first line after stripping: handles the case where the model
    # still returns "label\n\n" despite the stop sequences.
    predicted_label = content.strip().split("\n")[0].strip()
    return predicted_label


# -----------------------------
# Data loading & test split
# -----------------------------
def load_test_split():
    """Load the dataset and return the test split.

    Uses the same ``test_size``/``seed``/shuffle as the finetune script so the
    evaluated samples are exactly the held-out set.
    """
    with open(data_path, "r", encoding="utf-8") as f:
        raw_data = json.load(f)

    raw_dataset = Dataset.from_list(raw_data)
    split_dataset = raw_dataset.train_test_split(
        test_size=test_size, seed=seed, shuffle=True
    )
    test_raw = split_dataset["test"]
    return test_raw


# -----------------------------
# Evaluation
# -----------------------------
def evaluate_with_vllm(test_split):
    """Run inference over the test split and compute exact-match accuracy.

    Returns:
        (results, accuracy) where ``results`` is a list of per-sample dicts
        (texts, gold/predicted labels, correctness) and ``accuracy`` is the
        fraction of exact label matches (0.0 when there are no examples).
        Request failures are recorded as "ERROR: ..." predictions and count
        as incorrect rather than aborting the run.
    """
    examples = build_classification_examples(test_split)

    results = []
    total = 0
    correct = 0

    for idx, ex in enumerate(examples):
        fulltext = ex["fulltext"]
        gen_text = ex["gen_text"]
        gold_label = ex["gold_label"]
        user_prompt = ex["user_prompt"]

        try:
            pred_label = vllm_generate_label(user_prompt)
        except Exception as e:
            # Best-effort: keep evaluating remaining samples, record the error.
            pred_label = f"ERROR: {e}"

        total += 1
        is_correct = pred_label == gold_label
        if is_correct:
            correct += 1

        results.append(
            {
                "sample_index": idx,
                "fulltext": fulltext,
                "gen_text": gen_text,
                "gold_label": gold_label,
                "predicted_label": pred_label,
                "correct": is_correct,
            }
        )

    accuracy = correct / total if total else 0.0
    return results, accuracy


def main():
    """Evaluate the served model on the test split and persist results."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_tag = os.path.basename(str(VLLM_MODEL_NAME)).replace(".", "_")

    test_raw = load_test_split()
    results, accuracy = evaluate_with_vllm(test_raw)

    metrics = {
        "mode": "fulltext_gen_text_classification",
        "model_name": VLLM_MODEL_NAME,
        "dataset_path": data_path,
        "prompt_language": prompt_language,
        "seed": seed,
        "test_size": test_size,
        "examples_evaluated": len(results),
        "accuracy": accuracy,
        "timestamp": timestamp,
        "inference_backend": "vllm_openai_server",
    }

    predictions_path = os.path.join(
        model_info_dir, f"{model_tag}_vllm_test_inference_{timestamp}.json"
    )
    accuracy_path = os.path.join(
        ablation_dir, f"{model_tag}_vllm_classification_{timestamp}.json"
    )

    with open(predictions_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    with open(accuracy_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, ensure_ascii=False, indent=2)

    print(f"Saved vLLM test inference to: {predictions_path}")
    print(f"Saved vLLM test accuracy to: {accuracy_path}")
    print(f"Accuracy: {accuracy:.4f}")


if __name__ == "__main__":
    main()