Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -121,26 +121,20 @@ def main():
|
|
| 121 |
|
| 122 |
parser = HfArgumentParser((
|
| 123 |
ModelArguments))
|
| 124 |
-
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
| 125 |
-
# If we pass only one argument to the script and it's the path to a json file,
|
| 126 |
-
# let's parse it to get our arguments.
|
| 127 |
-
model_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
|
| 128 |
-
else:
|
| 129 |
-
model_args = parser.parse_args_into_dataclasses()[0]
|
| 130 |
|
| 131 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 132 |
"THUDM/chatglm-6b-int4", trust_remote_code=True)
|
| 133 |
config = AutoConfig.from_pretrained(
|
| 134 |
"MOSS550V/divination", trust_remote_code=True)
|
| 135 |
|
| 136 |
-
config.pre_seq_len =
|
| 137 |
-
config.prefix_projection =
|
| 138 |
|
| 139 |
ptuning_checkpoint = "MOSS550V/divination"
|
| 140 |
|
| 141 |
if ptuning_checkpoint is not None:
|
| 142 |
print(f"Loading prefix_encoder weight from {ptuning_checkpoint}")
|
| 143 |
-
model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
|
| 144 |
prefix_state_dict = torch.load(os.path.join(ptuning_checkpoint, "pytorch_model.bin"))
|
| 145 |
new_prefix_state_dict = {}
|
| 146 |
for k, v in prefix_state_dict.items():
|
|
@@ -150,14 +144,11 @@ def main():
|
|
| 150 |
else:
|
| 151 |
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", config=config, trust_remote_code=True)
|
| 152 |
|
| 153 |
-
|
| 154 |
-
print(f"Quantized to {model_args.quantization_bit} bit")
|
| 155 |
-
model = model.quantize(model_args.quantization_bit)
|
| 156 |
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
model.transformer.prefix_encoder.float()
|
| 161 |
|
| 162 |
model = model.eval()
|
| 163 |
demo.queue().launch(share=False, inbrowser=True)
|
|
|
|
| 121 |
|
| 122 |
parser = HfArgumentParser((
|
| 123 |
ModelArguments))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 126 |
"THUDM/chatglm-6b-int4", trust_remote_code=True)
|
| 127 |
config = AutoConfig.from_pretrained(
|
| 128 |
"MOSS550V/divination", trust_remote_code=True)
|
| 129 |
|
| 130 |
+
config.pre_seq_len = 128
|
| 131 |
+
config.prefix_projection = False
|
| 132 |
|
| 133 |
ptuning_checkpoint = "MOSS550V/divination"
|
| 134 |
|
| 135 |
if ptuning_checkpoint is not None:
|
| 136 |
print(f"Loading prefix_encoder weight from {ptuning_checkpoint}")
|
| 137 |
+
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", config=config, trust_remote_code=True)
|
| 138 |
prefix_state_dict = torch.load(os.path.join(ptuning_checkpoint, "pytorch_model.bin"))
|
| 139 |
new_prefix_state_dict = {}
|
| 140 |
for k, v in prefix_state_dict.items():
|
|
|
|
| 144 |
else:
|
| 145 |
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", config=config, trust_remote_code=True)
|
| 146 |
|
| 147 |
+
model = model.quantize(4)
|
|
|
|
|
|
|
| 148 |
|
| 149 |
+
# P-tuning v2
|
| 150 |
+
model = model.half()
|
| 151 |
+
model.transformer.prefix_encoder.float()
|
|
|
|
| 152 |
|
| 153 |
model = model.eval()
|
| 154 |
demo.queue().launch(share=False, inbrowser=True)
|