Upload jepa_retrieval.py
Browse files- jepa_retrieval.py +61 -9
jepa_retrieval.py
CHANGED
|
@@ -942,7 +942,7 @@ def setup_model_and_tokenizer(
|
|
| 942 |
# Load model
|
| 943 |
model = AutoModelForCausalLM.from_pretrained(
|
| 944 |
model_name,
|
| 945 |
-
|
| 946 |
device_map="auto",
|
| 947 |
trust_remote_code=True,
|
| 948 |
use_cache=False,
|
|
@@ -1074,15 +1074,67 @@ def run_inference(args):
|
|
| 1074 |
|
| 1075 |
# Load model
|
| 1076 |
print("\n1. Loading model...")
|
| 1077 |
-
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
| 1078 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 1079 |
-
args.model_name,
|
| 1080 |
-
torch_dtype=torch.bfloat16,
|
| 1081 |
-
device_map="auto",
|
| 1082 |
-
)
|
| 1083 |
|
| 1084 |
-
if
|
| 1085 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1086 |
|
| 1087 |
# Load corpus
|
| 1088 |
print("\n2. Loading corpus...")
|
|
|
|
| 942 |
# Load model
|
| 943 |
model = AutoModelForCausalLM.from_pretrained(
|
| 944 |
model_name,
|
| 945 |
+
dtype=torch.bfloat16,
|
| 946 |
device_map="auto",
|
| 947 |
trust_remote_code=True,
|
| 948 |
use_cache=False,
|
|
|
|
| 1074 |
|
| 1075 |
# Load model
|
| 1076 |
print("\n1. Loading model...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1077 |
|
| 1078 |
+
# Check whether this is a LoRA/PEFT checkpoint (i.e. the directory contains adapter_config.json)
|
| 1079 |
+
model_path = args.model_name
|
| 1080 |
+
adapter_config_path = os.path.join(model_path, "adapter_config.json")
|
| 1081 |
+
is_peft_checkpoint = os.path.exists(adapter_config_path)
|
| 1082 |
+
|
| 1083 |
+
if is_peft_checkpoint:
|
| 1084 |
+
print(f"Detected PEFT/LoRA checkpoint at {model_path}")
|
| 1085 |
+
|
| 1086 |
+
# Read adapter config to get base model name
|
| 1087 |
+
with open(adapter_config_path, 'r') as f:
|
| 1088 |
+
adapter_config = json.load(f)
|
| 1089 |
+
base_model_name = adapter_config.get("base_model_name_or_path", "meta-llama/Llama-3.2-1B-Instruct")
|
| 1090 |
+
print(f"Base model: {base_model_name}")
|
| 1091 |
+
|
| 1092 |
+
# Load tokenizer from checkpoint (has special tokens)
|
| 1093 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 1094 |
+
|
| 1095 |
+
if tokenizer.pad_token is None:
|
| 1096 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 1097 |
+
|
| 1098 |
+
# Load base model
|
| 1099 |
+
from peft import PeftModel
|
| 1100 |
+
|
| 1101 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 1102 |
+
base_model_name,
|
| 1103 |
+
torch_dtype=torch.bfloat16,
|
| 1104 |
+
device_map="auto",
|
| 1105 |
+
local_files_only=False,
|
| 1106 |
+
)
|
| 1107 |
+
|
| 1108 |
+
# Resize embeddings to match the tokenizer's vocabulary (including special tokens) before the adapter is loaded
|
| 1109 |
+
base_model.resize_token_embeddings(len(tokenizer))
|
| 1110 |
+
|
| 1111 |
+
# Load PEFT adapter
|
| 1112 |
+
model = PeftModel.from_pretrained(base_model, model_path)
|
| 1113 |
+
model = model.merge_and_unload() # Merge for faster inference
|
| 1114 |
+
print("Loaded and merged PEFT adapter")
|
| 1115 |
+
|
| 1116 |
+
else:
|
| 1117 |
+
# Regular (non-PEFT) checkpoint: an already-merged model or a plain base model
|
| 1118 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 1119 |
+
|
| 1120 |
+
# Add special tokens (must match the tokens used during training)
|
| 1121 |
+
special_tokens = ["[QUERY]", "[DOC]", "[SUPPORT]", "[ANSWER]"]
|
| 1122 |
+
new_tokens = [t for t in special_tokens if t not in tokenizer.get_vocab()]
|
| 1123 |
+
if new_tokens:
|
| 1124 |
+
tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})
|
| 1125 |
+
print(f"Added {len(new_tokens)} special tokens: {new_tokens}")
|
| 1126 |
+
|
| 1127 |
+
if tokenizer.pad_token is None:
|
| 1128 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 1129 |
+
|
| 1130 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 1131 |
+
model_path,
|
| 1132 |
+
torch_dtype=torch.bfloat16,
|
| 1133 |
+
device_map="auto",
|
| 1134 |
+
)
|
| 1135 |
+
|
| 1136 |
+
# Always resize embeddings to match tokenizer
|
| 1137 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 1138 |
|
| 1139 |
# Load corpus
|
| 1140 |
print("\n2. Loading corpus...")
|