Upload folder using huggingface_hub
app.py CHANGED

@@ -78,10 +78,12 @@ def model_worker(
     }[args.dtype]
     with default_tensor_type(dtype=target_dtype, device="cuda"):
         model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
-    print(f"Loading pretrained weights {args.pretrained_path}")
-    checkpoint = torch.load(args.pretrained_path, map_location='cpu')
-    msg = model.load_state_dict(checkpoint, strict=False)
-    # print("load result:\n", msg)
+    for ckpt_id in range(args.num_ckpts):
+        ckpt_path = hf_hub_download(repo_id=args.pretrained_path, filename=args.ckpt_format.format(str(ckpt_id)))
+        print(f"Loading pretrained weights {ckpt_path}")
+        checkpoint = torch.load(ckpt_path, map_location='cpu')
+        msg = model.load_state_dict(checkpoint, strict=False)
+        # print("load result:\n", msg)
     model.cuda()
     model.eval()
     print(f"Model = {str(model)}")
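The new loop, as a standalone sketch: each of the num_ckpts files holds a disjoint slice of the model's state_dict, so one partial load_state_dict is applied per shard, with strict=False so that the keys missing from any single shard are not treated as errors. The helper below is a minimal illustration; the function name load_sharded_weights is ours, and the repo id and shard names follow the DemoConfig in the second hunk.

import torch
from huggingface_hub import hf_hub_download

def load_sharded_weights(model, repo_id="csuhan/OneLLM-7B-hf",
                         ckpt_format="consolidated.00-of-01.s{}.pth",
                         num_ckpts=10):
    # Fetch each shard from the Hub (cached locally after the first call)
    # and apply its partial state_dict to the model.
    for ckpt_id in range(num_ckpts):
        ckpt_path = hf_hub_download(repo_id=repo_id,
                                    filename=ckpt_format.format(ckpt_id))
        shard = torch.load(ckpt_path, map_location="cpu")
        # strict=False: each shard covers only a subset of the parameters.
        model.load_state_dict(shard, strict=False)

A stricter variant would merge the shards into one dict and call load_state_dict once with strict=True, which would still catch genuinely missing keys; the per-shard approach trades that check for lower peak host memory.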
@@ -242,7 +244,10 @@ class DemoConfig:
     llama_config = "config/llama2/7B.json"
     model_max_seq_len = 2048
     # pretrained_path = "weights/7B_2048/consolidated.00-of-01.pth"
-    pretrained_path = hf_hub_download(repo_id="csuhan/OneLLM-7B", filename="consolidated.00-of-01.pth")
+    # pretrained_path = hf_hub_download(repo_id="csuhan/OneLLM-7B", filename="consolidated.00-of-01.pth")
+    pretrained_path = "csuhan/OneLLM-7B-hf"
+    ckpt_format = "consolidated.00-of-01.s{}.pth"
+    num_ckpts = 10
     master_port = 23861
     master_addr = "127.0.0.1"
     dtype = "fp16"
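The config now names a Hub repo rather than a single local file: pretrained_path is the repo id, and ckpt_format together with num_ckpts expands to consolidated.00-of-01.s0.pth through consolidated.00-of-01.s9.pth. The commit shows only the consumer side; one way such shards could plausibly be produced, striping the keys of a single large checkpoint across ten smaller files, is sketched below. This split scheme is an assumption for illustration, not the actual script behind csuhan/OneLLM-7B-hf.

import torch

def split_checkpoint(src_path, num_ckpts=10,
                     ckpt_format="consolidated.00-of-01.s{}.pth"):
    # Hypothetical producer side: partition one large state_dict into
    # num_ckpts disjoint shards. Any disjoint split works, because the
    # loader applies every shard with strict=False and the union of the
    # shards reconstructs the full state_dict.
    state_dict = torch.load(src_path, map_location="cpu")
    keys = list(state_dict.keys())
    for i in range(num_ckpts):
        shard = {k: state_dict[k] for k in keys[i::num_ckpts]}
        torch.save(shard, ckpt_format.format(i))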