RAGQA / model.py
hmm183's picture
Create model.py
e6145fe verified
# model.py
import torch
import os
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel, LoraConfig, TaskType
def load_t5_qa_model(base_model_name, adapter_path, device="cpu"):
    """Load a seq2seq base model, merge a PEFT adapter into it, and return it.

    The adapter at ``adapter_path`` is attached to the base model with
    ``PeftModel.from_pretrained`` and then folded into the base weights via
    ``merge_and_unload()``. The result is switched to eval mode and moved to
    ``device``.

    Args:
        base_model_name: Identifier passed to ``from_pretrained`` for both
            the tokenizer and the base seq2seq model.
        adapter_path: Local filesystem path of the saved PEFT adapter. It
            must exist on disk.
        device: Target passed to ``model.to(...)``. Defaults to ``"cpu"``.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        FileNotFoundError: If ``adapter_path`` does not exist.
    """
    # Validate the adapter location first so that a bad path fails
    # immediately instead of after the tokenizer and base model have
    # already been fetched and loaded.
    if not os.path.exists(adapter_path):
        raise FileNotFoundError(f"Adapter path not found: {adapter_path}")

    print("Loading T5 QA tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    print(f"Loading base T5 model: {base_model_name}...")
    base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

    print(f"Loading PEFT adapter from {adapter_path} and merging...")
    # No LoraConfig is built here: the call below receives only the base
    # model and the adapter path, so a locally constructed config would be
    # unused.
    model = PeftModel.from_pretrained(base_model, adapter_path)
    model = model.merge_and_unload()

    model.eval()
    model.to(device)
    print(f"Model loaded and moved to {device}.")
    return tokenizer, model