AnKhanh committed
Commit c0358fb · verified · 1 Parent(s): e8b74fb

Upload jepa_retrieval.py

Files changed (1): jepa_retrieval.py (+61, -9)
jepa_retrieval.py CHANGED
@@ -942,7 +942,7 @@ def setup_model_and_tokenizer(
     # Load model
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
-        torch_dtype=torch.bfloat16,
+        dtype=torch.bfloat16,
         device_map="auto",
         trust_remote_code=True,
         use_cache=False,
@@ -1074,15 +1074,67 @@ def run_inference(args):

     # Load model
     print("\n1. Loading model...")
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
-    model = AutoModelForCausalLM.from_pretrained(
-        args.model_name,
-        torch_dtype=torch.bfloat16,
-        device_map="auto",
-    )
-
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
+
+    # Check if this is a LoRA/PEFT checkpoint (has adapter_config.json)
+    model_path = args.model_name
+    adapter_config_path = os.path.join(model_path, "adapter_config.json")
+    is_peft_checkpoint = os.path.exists(adapter_config_path)
+
+    if is_peft_checkpoint:
+        print(f"Detected PEFT/LoRA checkpoint at {model_path}")
+
+        # Read adapter config to get base model name
+        with open(adapter_config_path, 'r') as f:
+            adapter_config = json.load(f)
+        base_model_name = adapter_config.get("base_model_name_or_path", "meta-llama/Llama-3.2-1B-Instruct")
+        print(f"Base model: {base_model_name}")
+
+        # Load tokenizer from checkpoint (has special tokens)
+        tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        # Load base model
+        from peft import PeftModel
+
+        base_model = AutoModelForCausalLM.from_pretrained(
+            base_model_name,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            local_files_only=False,
+        )
+
+        # Resize embeddings to match tokenizer (with special tokens)
+        base_model.resize_token_embeddings(len(tokenizer))
+
+        # Load PEFT adapter
+        model = PeftModel.from_pretrained(base_model, model_path)
+        model = model.merge_and_unload()  # Merge for faster inference
+        print("Loaded and merged PEFT adapter")
+
+    else:
+        # Regular model (merged or base)
+        tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+        # Add special tokens (must match training)
+        special_tokens = ["[QUERY]", "[DOC]", "[SUPPORT]", "[ANSWER]"]
+        new_tokens = [t for t in special_tokens if t not in tokenizer.get_vocab()]
+        if new_tokens:
+            tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})
+            print(f"Added {len(new_tokens)} special tokens: {new_tokens}")
+
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+        )
+
+        # Always resize embeddings to match tokenizer
+        model.resize_token_embeddings(len(tokenizer))

     # Load corpus
     print("\n2. Loading corpus...")