piliguori commited on
Commit
fbf65ef
·
verified ·
1 Parent(s): e4047a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -191,27 +191,29 @@ def find_example_files():
191
  # ==========================
192
 
193
  BASE_MODEL_ID = "Salesforce/codet5p-770m"
194
- FINETUNED_REPO_ID = "OSS-forge/codet5p-770m-pyresbugs"
195
  FINETUNED_FILENAME = "pytorch_model.bin"
196
 
 
 
197
  print(f"Loading tokenizer from base model: {BASE_MODEL_ID}")
198
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
199
 
200
  print(f"Loading base model: {BASE_MODEL_ID}")
201
  model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_ID)
 
202
 
203
  print(f"Downloading fine-tuned weights from repo: {FINETUNED_REPO_ID}")
204
  ckpt_path = hf_hub_download(FINETUNED_REPO_ID, FINETUNED_FILENAME)
205
 
206
  print(f"Loading state_dict from: {ckpt_path}")
207
- state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))
208
 
209
  if "model_state_dict" in state_dict:
210
  state_dict = state_dict["model_state_dict"]
211
 
212
  missing, unexpected = model.load_state_dict(state_dict, strict=False)
213
- print("Loaded fine-tuned weights.")
214
- print("Missing keys:", len(missing), "Unexpected keys:", len(unexpected))
215
 
216
  model.eval()
217
 
 
191
  # ==========================
192
 
193
  BASE_MODEL_ID = "Salesforce/codet5p-770m"
194
+ FINETUNED_REPO_ID = "OSS-forge/codet5p-770m-pyresbugs"
195
  FINETUNED_FILENAME = "pytorch_model.bin"
196
 
197
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
198
+
199
  print(f"Loading tokenizer from base model: {BASE_MODEL_ID}")
200
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
201
 
202
  print(f"Loading base model: {BASE_MODEL_ID}")
203
  model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_ID)
204
+ model.to(device)
205
 
206
  print(f"Downloading fine-tuned weights from repo: {FINETUNED_REPO_ID}")
207
  ckpt_path = hf_hub_download(FINETUNED_REPO_ID, FINETUNED_FILENAME)
208
 
209
  print(f"Loading state_dict from: {ckpt_path}")
210
+ state_dict = torch.load(ckpt_path, map_location="cpu")
211
 
212
  if "model_state_dict" in state_dict:
213
  state_dict = state_dict["model_state_dict"]
214
 
215
  missing, unexpected = model.load_state_dict(state_dict, strict=False)
216
+ print(f"Loaded fine-tuned weights. Missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
 
217
 
218
  model.eval()
219