# ABOUTME: Interactive CLI for testing the fine-tuned model
# ABOUTME: Enter diary text and get disease activity score predictions

import argparse

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_model(
    adapter_path: str,
    base_model_name: str = "Qwen/Qwen2.5-3B-Instruct",
):
    """Load the fine-tuned model with merged LoRA adapter."""
    print(f"Loading model: {base_model_name}")

    # Pick the best available device and a dtype it supports
    if torch.backends.mps.is_available():
        device = "mps"
        model_dtype = torch.float16
    elif torch.cuda.is_available():
        device = "cuda"
        model_dtype = torch.bfloat16
    else:
        device = "cpu"
        model_dtype = torch.float32

    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        dtype=model_dtype,
        trust_remote_code=True,
    )

    print(f"Loading adapter: {adapter_path}")
    model = PeftModel.from_pretrained(base_model, adapter_path)
    model = model.merge_and_unload()
    model = model.to(device)
    model.eval()

    print("Model ready.\n")
    return model, tokenizer


def predict(model, tokenizer, diary_text: str) -> tuple[str | None, str]:
    """Run prediction on diary text, return (score, raw_output)."""
    # Build the prompt in the same format as the training data
    user_content = f"Diary: {diary_text} What is the disease activity score for today?"
    messages = [{"role": "user", "content": user_content}]

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens so the prompt text
    # (which may itself contain digits) is never scanned for a score
    input_length = inputs["input_ids"].shape[1]
    generated = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)

    # Extract score (first digit 0-3)
    score = None
    for char in generated:
        if char in "0123":
            score = char
            break

    return score, generated.strip()


def main():
    parser = argparse.ArgumentParser(description="Interactive model testing")
    parser.add_argument(
        "--adapter",
        type=str,
        required=True,
        help="Path to the LoRA adapter directory",
    )
    parser.add_argument(
        "--base-model",
        type=str,
        default="Qwen/Qwen2.5-3B-Instruct",
        help="Base model name",
    )
    args = parser.parse_args()

    model, tokenizer = load_model(args.adapter, args.base_model)

    print("=" * 60)
    print("Interactive Disease Activity Score Predictor")
    print("=" * 60)
    print("Enter diary text to get a prediction (0-3).")
    print("Type 'quit' or 'exit' to stop.\n")

    while True:
        try:
            diary_text = input("Diary> ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\nExiting.")
            break

        if not diary_text:
            continue
        if diary_text.lower() in ("quit", "exit", "q"):
            print("Exiting.")
            break

        score, raw = predict(model, tokenizer, diary_text)
        if score is not None:
            print(f" Score: {score}")
        else:
            print(f" Could not parse score from: {raw}")
        print()


if __name__ == "__main__":
    main()