| | |
| | |
| |
|
| | import argparse |
| |
|
| | import torch |
| | from peft import PeftModel |
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| |
|
| |
|
def load_model(
    adapter_path: str,
    base_model_name: str = "Qwen/Qwen2.5-3B-Instruct",
):
    """Load a base model, apply a LoRA adapter, and merge it for inference.

    Returns a ``(model, tokenizer)`` pair placed on the best available
    device (MPS, then CUDA, then CPU) with a dtype suited to that device.
    """
    print(f"Loading model: {base_model_name}")

    # Pick device + dtype together: fp16 on Apple silicon (bf16 support on
    # MPS is spotty), bf16 on CUDA, fp32 on CPU.
    if torch.backends.mps.is_available():
        device, model_dtype = "mps", torch.float16
    elif torch.cuda.is_available():
        device, model_dtype = "cuda", torch.bfloat16
    else:
        device, model_dtype = "cpu", torch.float32

    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    # Some checkpoints ship without a pad token; reuse EOS so generation
    # with padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        dtype=model_dtype,
        trust_remote_code=True,
    )

    print(f"Loading adapter: {adapter_path}")
    # Fold the LoRA weights into the base weights so inference runs
    # without the PEFT wrapper overhead.
    merged = PeftModel.from_pretrained(base_model, adapter_path)
    merged = merged.merge_and_unload().to(device)
    merged.eval()

    print("Model ready.\n")
    return merged, tokenizer
| |
|
| |
|
def predict(model, tokenizer, diary_text: str) -> tuple[str | None, str]:
    """Run prediction on diary text, return (score, raw_output).

    Builds a chat-formatted prompt, generates greedily, and scans the
    model's completion for the first digit in the valid range 0-3.

    Args:
        model: A causal LM exposing ``.generate`` and ``.device``.
        tokenizer: The matching tokenizer (must support chat templates).
        diary_text: Free-text diary entry to score.

    Returns:
        ``(score, raw_output)`` where ``score`` is ``"0"``..``"3"`` or
        ``None`` when no valid digit was produced, and ``raw_output`` is
        the stripped completion text.
    """
    user_content = f"Diary: {diary_text} What is the disease activity score for today?"
    messages = [{"role": "user", "content": user_content}]

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )

    # BUGFIX: the old code sliced the decoded string at len(text), but
    # `text` contains chat-template special tokens that decode() strips
    # (skip_special_tokens=True), so the character offsets did not line
    # up and the prompt/completion boundary was wrong. Strip the prompt
    # at the *token* level instead and decode only the new ids.
    prompt_len = inputs["input_ids"].shape[-1]
    generated_ids = outputs[0][prompt_len:]
    generated = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # First in-range digit wins; anything else counts as unparseable.
    score = next((ch for ch in generated if ch in "0123"), None)

    return score, generated.strip()
| |
|
| |
|
def main():
    """CLI entry point: load the merged model, then loop reading diary
    entries from stdin and printing the predicted score for each."""
    parser = argparse.ArgumentParser(description="Interactive model testing")
    parser.add_argument(
        "--adapter",
        type=str,
        required=True,
        help="Path to the LoRA adapter directory",
    )
    parser.add_argument(
        "--base-model",
        type=str,
        default="Qwen/Qwen2.5-3B-Instruct",
        help="Base model name",
    )
    args = parser.parse_args()

    model, tokenizer = load_model(args.adapter, args.base_model)

    banner = "=" * 60
    print(banner)
    print("Interactive Disease Activity Score Predictor")
    print(banner)
    print("Enter diary text to get a prediction (0-3).")
    print("Type 'quit' or 'exit' to stop.\n")

    while True:
        try:
            diary_text = input("Diary> ").strip()
        except (KeyboardInterrupt, EOFError):
            # Ctrl-C / Ctrl-D end the session cleanly instead of tracing.
            print("\nExiting.")
            break

        if not diary_text:
            continue

        if diary_text.lower() in ("quit", "exit", "q"):
            print("Exiting.")
            break

        score, raw = predict(model, tokenizer, diary_text)

        if score is None:
            print(f" Could not parse score from: {raw}")
        else:
            print(f" Score: {score}")
        print()
|
| |
|
# Run the interactive CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
| |
|