# ff6347's picture (Hugging Face upload metadata — not part of the script)
# Upload folder using huggingface_hub
# 9ec3d1d verified
# ABOUTME: Interactive CLI for testing the fine-tuned model
# ABOUTME: Enter diary text and get disease activity score predictions
import argparse
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
def load_model(
    adapter_path: str,
    base_model_name: str = "Qwen/Qwen2.5-3B-Instruct",
):
    """Load the base model, apply the LoRA adapter, and merge it in.

    Args:
        adapter_path: Directory containing the trained LoRA adapter.
        base_model_name: Hugging Face model id of the base model.

    Returns:
        A ``(model, tokenizer)`` tuple ready for inference on the chosen
        device, with the adapter weights folded into the base model.
    """
    print(f"Loading model: {base_model_name}")

    # Pick the best available accelerator; fp16 on Apple MPS, bf16 on CUDA,
    # and full fp32 on CPU.
    if torch.backends.mps.is_available():
        device, model_dtype = "mps", torch.float16
    elif torch.cuda.is_available():
        device, model_dtype = "cuda", torch.bfloat16
    else:
        device, model_dtype = "cpu", torch.float32
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    if tokenizer.pad_token is None:
        # Some tokenizers ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        dtype=model_dtype,
        trust_remote_code=True,
    )

    print(f"Loading adapter: {adapter_path}")
    peft_model = PeftModel.from_pretrained(base_model, adapter_path)
    # Merge the LoRA weights into the base model so inference runs without
    # the PEFT wrapper overhead.
    merged = peft_model.merge_and_unload().to(device)
    merged.eval()
    print("Model ready.\n")
    return merged, tokenizer
def predict(model, tokenizer, diary_text: str) -> tuple[str | None, str]:
"""Run prediction on diary text, return (score, raw_output)."""
# Build the prompt in the same format as training data
user_content = f"Diary: {diary_text} What is the disease activity score for today?"
messages = [{"role": "user", "content": user_content}]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
inputs = tokenizer(text, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=10,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
generated = response[len(text) :] if len(response) > len(text) else response
# Extract score (first digit 0-3)
score = None
for char in generated:
if char in "0123":
score = char
break
return score, generated.strip()
def main():
    """Parse CLI arguments and run the interactive prediction loop."""
    parser = argparse.ArgumentParser(description="Interactive model testing")
    parser.add_argument(
        "--adapter",
        type=str,
        required=True,
        help="Path to the LoRA adapter directory",
    )
    parser.add_argument(
        "--base-model",
        type=str,
        default="Qwen/Qwen2.5-3B-Instruct",
        help="Base model name",
    )
    args = parser.parse_args()

    model, tokenizer = load_model(args.adapter, args.base_model)

    banner = "=" * 60
    print(banner)
    print("Interactive Disease Activity Score Predictor")
    print(banner)
    print("Enter diary text to get a prediction (0-3).")
    print("Type 'quit' or 'exit' to stop.\n")

    # Read-eval-print loop: blank lines are ignored, Ctrl-C/Ctrl-D or a
    # quit keyword ends the session.
    while True:
        try:
            diary_text = input("Diary> ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\nExiting.")
            break
        if not diary_text:
            continue
        if diary_text.lower() in ("quit", "exit", "q"):
            print("Exiting.")
            break
        score, raw = predict(model, tokenizer, diary_text)
        if score is None:
            print(f" Could not parse score from: {raw}")
        else:
            print(f" Score: {score}")
        print()
# Standard script entry guard: only run the CLI when executed directly.
if __name__ == "__main__":
    main()