| |
| import torch |
| import os |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| from peft import PeftModel, LoraConfig, TaskType |
|
|
def load_t5_qa_model(base_model_name, adapter_path, device="cpu"):
    """Load a T5 tokenizer and a seq2seq base model with a PEFT adapter merged in.

    The adapter stored at ``adapter_path`` is attached to the base model with
    ``PeftModel.from_pretrained`` and then folded into it with
    ``merge_and_unload()``. The merged model is switched to eval mode and
    moved to ``device``.

    Args:
        base_model_name: Identifier or local path passed to ``from_pretrained``
            for both the tokenizer and the base model.
        adapter_path: Local filesystem path of the saved PEFT adapter.
        device: Target passed to ``model.to(...)`` (e.g. ``"cpu"``).

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        FileNotFoundError: If ``adapter_path`` does not exist on disk.
    """
    # Validate the adapter location before any expensive loading, so a bad
    # path fails immediately instead of after the base model is in memory.
    # NOTE(review): this local-path check also rejects Hugging Face Hub repo
    # IDs, which PeftModel.from_pretrained could otherwise accept — confirm
    # that only local adapters are intended.
    if not os.path.exists(adapter_path):
        raise FileNotFoundError(f"Adapter path not found: {adapter_path}")

    print("Loading T5 QA tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    print(f"Loading base T5 model: {base_model_name}...")
    base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

    print(f"Loading PEFT adapter from {adapter_path} and merging...")
    # PeftModel.from_pretrained reads the adapter's own saved config from
    # adapter_path, so no LoraConfig has to be constructed here.
    model = PeftModel.from_pretrained(base_model, adapter_path)
    # Merge the adapter weights into the base model and drop the PEFT wrapper.
    model = model.merge_and_unload()
    model.eval()
    model.to(device)

    print(f"Model loaded and moved to {device}.")
    return tokenizer, model
|
|