File size: 3,672 Bytes
# ABOUTME: Interactive CLI for testing the fine-tuned model
# ABOUTME: Enter diary text and get disease activity score predictions
import argparse
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
def load_model(
    adapter_path: str,
    base_model_name: str = "Qwen/Qwen2.5-3B-Instruct",
):
    """Load the base model, apply the LoRA adapter, and merge it in.

    Args:
        adapter_path: directory containing the saved LoRA adapter.
        base_model_name: Hugging Face model id (or local path) of the base model.

    Returns:
        ``(model, tokenizer)``: the merged model in eval mode on the selected
        device, and the base model's tokenizer.
    """
    print(f"Loading model: {base_model_name}")

    # Probe accelerators in preference order (MPS, then CUDA); each entry pairs
    # a device with the dtype used on it. The availability checks run lazily,
    # in order, stopping at the first hit. CPU/float32 is the fallback.
    for device, model_dtype, is_available in (
        ("mps", torch.float16, torch.backends.mps.is_available),
        ("cuda", torch.bfloat16, torch.cuda.is_available),
    ):
        if is_available():
            break
    else:
        device, model_dtype = "cpu", torch.float32
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    # Generation is called with pad_token_id, so make sure one exists.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): trust_remote_code=True lets the model repo execute its own
    # code; the tokenizer above is loaded without it.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        dtype=model_dtype,
        trust_remote_code=True,
    )

    print(f"Loading adapter: {adapter_path}")
    # Fold the LoRA weights into the base weights so inference runs on a
    # plain model, then move it to the chosen device.
    merged = PeftModel.from_pretrained(base_model, adapter_path).merge_and_unload()
    merged = merged.to(device)
    merged.eval()
    print("Model ready.\n")
    return merged, tokenizer
def predict(model, tokenizer, diary_text: str) -> tuple[str | None, str]:
"""Run prediction on diary text, return (score, raw_output)."""
# Build the prompt in the same format as training data
user_content = f"Diary: {diary_text} What is the disease activity score for today?"
messages = [{"role": "user", "content": user_content}]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
inputs = tokenizer(text, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=10,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
generated = response[len(text) :] if len(response) > len(text) else response
# Extract score (first digit 0-3)
score = None
for char in generated:
if char in "0123":
score = char
break
return score, generated.strip()
def main():
    """Parse CLI arguments, load the model, then score diary entries from stdin.

    The loop ends on 'quit', 'exit' or 'q' (case-insensitive), on Ctrl-C, or
    on end of input. Blank lines are ignored.
    """
    parser = argparse.ArgumentParser(description="Interactive model testing")
    parser.add_argument(
        "--adapter", type=str, required=True, help="Path to the LoRA adapter directory"
    )
    parser.add_argument(
        "--base-model", type=str, default="Qwen/Qwen2.5-3B-Instruct", help="Base model name"
    )
    args = parser.parse_args()

    model, tokenizer = load_model(args.adapter, args.base_model)

    rule = "=" * 60
    for banner_line in (
        rule,
        "Interactive Disease Activity Score Predictor",
        rule,
        "Enter diary text to get a prediction (0-3).",
        "Type 'quit' or 'exit' to stop.\n",
    ):
        print(banner_line)

    stop_commands = ("quit", "exit", "q")
    while True:
        try:
            entry = input("Diary> ").strip()
        except (KeyboardInterrupt, EOFError):
            # Treat Ctrl-C / Ctrl-D at the prompt as a normal exit.
            print("\nExiting.")
            break

        if not entry:
            continue
        if entry.lower() in stop_commands:
            print("Exiting.")
            break

        score, raw = predict(model, tokenizer, entry)
        print(
            f" Score: {score}"
            if score is not None
            else f" Could not parse score from: {raw}"
        )
        print()
if __name__ == "__main__":
    # Start the interactive session only when run as a script, never on import.
    main()
|