piliguori commited on
Commit
d6272b6
·
verified ·
1 Parent(s): 23479ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -7
app.py CHANGED
@@ -5,6 +5,8 @@ import autopep8
5
  import glob
6
  import re
7
  import os
 
 
8
 
9
  # ==========================
10
  # Utility functions
@@ -188,17 +190,30 @@ def find_example_files():
188
  # Load model from HF Hub
189
  # ==========================
190
 
191
- BASE_MODEL_ID = "Salesforce/codet5p-770m"
192
- FINETUNED_MODEL_ID = "OSS-Forge/codet5p-770m-pyresbugs"
 
193
 
194
  print(f"Loading tokenizer from base model: {BASE_MODEL_ID}")
195
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
196
 
197
- print(f"Loading fine-tuned model weights from: {FINETUNED_MODEL_ID}")
198
- model = AutoModelForSeq2SeqLM.from_pretrained(
199
- FINETUNED_MODEL_ID,
200
- ignore_mismatched_sizes=True,
201
- )
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
 
204
 
 
5
  import glob
6
  import re
7
  import os
8
+ from huggingface_hub import hf_hub_download
9
+
10
 
11
  # ==========================
12
  # Utility functions
 
190
  # Load model from HF Hub
191
  # ==========================
192
 
193
+ BASE_MODEL_ID = "Salesforce/codet5p-770m"
194
+ FINETUNED_REPO_ID = "OSS-forge/codet5p-770m-pyresbugs"
195
+ FINETUNED_FILENAME = "pytorch_model.bin"
196
 
197
  print(f"Loading tokenizer from base model: {BASE_MODEL_ID}")
198
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
199
 
200
+ print(f"Loading base model: {BASE_MODEL_ID}")
201
+ model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_ID)
202
+
203
+ print(f"Downloading fine-tuned weights from repo: {FINETUNED_REPO_ID}")
204
+ ckpt_path = hf_hub_download(FINETUNED_REPO_ID, FINETUNED_FILENAME)
205
+
206
+ print(f"Loading state_dict from: {ckpt_path}")
207
+ state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))
208
+
209
+ if "model_state_dict" in state_dict:
210
+ state_dict = state_dict["model_state_dict"]
211
+
212
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
213
+ print("Loaded fine-tuned weights.")
214
+ print("Missing keys:", len(missing), "Unexpected keys:", len(unexpected))
215
+
216
+ model.eval()
217
 
218
 
219