Update app.py
Browse files
app.py
CHANGED
|
@@ -191,27 +191,29 @@ def find_example_files():
|
|
| 191 |
# ==========================
|
| 192 |
|
| 193 |
BASE_MODEL_ID = "Salesforce/codet5p-770m"
|
| 194 |
-
FINETUNED_REPO_ID = "OSS-forge/codet5p-770m-pyresbugs"
|
| 195 |
FINETUNED_FILENAME = "pytorch_model.bin"
|
| 196 |
|
|
|
|
|
|
|
| 197 |
print(f"Loading tokenizer from base model: {BASE_MODEL_ID}")
|
| 198 |
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
|
| 199 |
|
| 200 |
print(f"Loading base model: {BASE_MODEL_ID}")
|
| 201 |
model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_ID)
|
|
|
|
| 202 |
|
| 203 |
print(f"Downloading fine-tuned weights from repo: {FINETUNED_REPO_ID}")
|
| 204 |
ckpt_path = hf_hub_download(FINETUNED_REPO_ID, FINETUNED_FILENAME)
|
| 205 |
|
| 206 |
print(f"Loading state_dict from: {ckpt_path}")
|
| 207 |
-
state_dict = torch.load(ckpt_path, map_location=device)
|
| 208 |
|
| 209 |
if "model_state_dict" in state_dict:
|
| 210 |
state_dict = state_dict["model_state_dict"]
|
| 211 |
|
| 212 |
missing, unexpected = model.load_state_dict(state_dict, strict=False)
|
| 213 |
-
print("Loaded fine-tuned weights.")
|
| 214 |
-
print("Missing keys:", len(missing), "Unexpected keys:", len(unexpected))
|
| 215 |
|
| 216 |
model.eval()
|
| 217 |
|
|
|
|
# ==========================
# Model loading: base CodeT5+ checkpoint with fine-tuned weights overlaid.
# ==========================

BASE_MODEL_ID = "Salesforce/codet5p-770m"
FINETUNED_REPO_ID = "OSS-forge/codet5p-770m-pyresbugs"
FINETUNED_FILENAME = "pytorch_model.bin"

# Prefer GPU when available; everything below moves the model there once.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Loading tokenizer from base model: {BASE_MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

print(f"Loading base model: {BASE_MODEL_ID}")
model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_ID)
model.to(device)

print(f"Downloading fine-tuned weights from repo: {FINETUNED_REPO_ID}")
ckpt_path = hf_hub_download(FINETUNED_REPO_ID, FINETUNED_FILENAME)

print(f"Loading state_dict from: {ckpt_path}")
# Load onto CPU first; load_state_dict() below copies tensors into the
# model, which already lives on `device`.
# NOTE(review): torch.load unpickles arbitrary objects — this trusts the
# OSS-forge repo. Consider torch.load(..., weights_only=True) on modern
# PyTorch if the checkpoint contains only tensors.
state_dict = torch.load(ckpt_path, map_location="cpu")

# Some training scripts wrap the weights in a full checkpoint dict;
# unwrap the nested model weights when that layout is detected.
if "model_state_dict" in state_dict:
    state_dict = state_dict["model_state_dict"]

# strict=False: tolerate key mismatches (e.g. tied/buffer keys) but
# surface the counts so silent drift is still visible in the logs.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"Loaded fine-tuned weights. Missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")

# Inference-only app: disable dropout etc.
model.eval()