#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
python step_3_inter_model.py \
--model_name /root/autodl-tmp/OpenMiniMind/temp/trainer_output/Qwen3-8B-sft-deepspeed/checkpoint-100
python step_3_inter_model.py \
--model_name /root/autodl-tmp/OpenMiniMind/temp/trainer_output/Qwen3-8B-sft-deepspeed/checkpoint-100/Qwen3-8B-sft-deepspeed
Save only the files needed for inference (no optimizer states etc.) so the model can be pushed to a model repository.
python3 zero_to_fp32.py --max_shard_size 4GB --safe_serialization ./ ./Qwen3-8B-sft-deepspeed
mv config.json ./Qwen3-8B-sft-deepspeed
mv generation_config.json ./Qwen3-8B-sft-deepspeed
mv merges.txt ./Qwen3-8B-sft-deepspeed
mv tokenizer.json ./Qwen3-8B-sft-deepspeed
mv tokenizer_config.json ./Qwen3-8B-sft-deepspeed
mv vocab.json ./Qwen3-8B-sft-deepspeed
mv added_tokens.json ./Qwen3-8B-sft-deepspeed
mv chat_template.jinja ./Qwen3-8B-sft-deepspeed
mv special_tokens_map.json ./Qwen3-8B-sft-deepspeed
"""
import argparse
import os
from pathlib import Path
import platform
# Resolve project paths.
# On developer machines (Windows/macOS) the shared project_settings module is
# importable; otherwise fall back to hard-coded paths.
# NOTE(review): assumes the non-Windows/macOS host is the autodl training box
# with this exact directory layout — confirm before reuse.
if platform.system() in ("Windows", "Darwin"):
    from project_settings import project_path, temp_directory
else:
    project_path = os.path.abspath("../../../")
    project_path = Path(project_path)
    temp_directory = Path("/root/autodl-tmp/OpenMiniMind/temp")
# from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from modelscope import AutoConfig, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import TextStreamer
def get_args():
    """Parse command-line arguments for the interactive inference script.

    Returns:
        argparse.Namespace: parsed arguments (model_name, model_cache_dir,
        max_new_tokens, top_p, temperature, num_workers).
    """
    parser = argparse.ArgumentParser()
    # Fix: the original had a stray trailing comma after each add_argument()
    # call, building throwaway one-element tuples at statement level.
    parser.add_argument(
        "--model_name",
        # default="Qwen/Qwen3-8B",
        default=(temp_directory / "trained_models/Qwen3-8B-sft-deepspeed/checkpoint-100").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--model_cache_dir",
        default=(temp_directory / "hub_models").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--max_new_tokens",
        default=1024,  # e.g. 8192, 128
        type=int,
        help="最大生成长度(注意:并非模型实际长文本能力)",
    )
    parser.add_argument("--top_p", default=0.85, type=float, help="nucleus采样阈值(0-1)")
    parser.add_argument("--temperature", default=0.85, type=float, help="生成温度,控制随机性(0-1,越大越随机)")
    parser.add_argument(
        "--num_workers",
        # No meaningful CPU-count heuristic on dev machines; half the cores elsewhere.
        default=None if platform.system() in ("Windows", "Darwin") else os.cpu_count() // 2,
        type=int,
    )
    args = parser.parse_args()
    return args
def main():
    """Load a fine-tuned causal LM, run one chat-style generation, print the reply."""
    args = get_args()

    # Tokenizer first: it is cheap and fails fast on a bad checkpoint path.
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=args.model_name,
        trust_remote_code=True,
        cache_dir=args.model_cache_dir,
    )
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=args.model_name,
        quantization_config=None,
        device_map="auto",  # multi-GPU sharding; only matters when one card cannot hold the model
        trust_remote_code=True,
        cache_dir=args.model_cache_dir,
    )
    model.eval()
    # print(model)

    conversation = [
        {
            "role": "user",
            "content": "关键词识别:\n梯度功能材料是基于一种全新的材料设计概念而开发的新型功能材料.陶瓷-金属FGM的主要结构特点是各梯度层由不同体积浓度的陶瓷和金属组成,材料在升温和降温过程中宏观梯度层间产生热应力,每一梯度层中细观增强相和基体的热物性失配将产生单层热应力,从而导致材料整体的破坏.采用云纹干涉法,对具有四个梯度层的SiC/A1梯度功能材料分别在机载、热载及两者共同作用下进行了应变测试,分别得到了这三种情况下每梯度层同一位置的纵向应变,横向应变和剪应变值."
        }
    ]

    # Render the chat template as plain text (tokenize=False) and append the
    # assistant-turn prompt — add_generation_prompt must be True for inference
    # (it is disabled during training).
    prompt_text = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize the rendered prompt and move the tensors onto the model's device.
    model_inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)

    # Sample a completion, streaming tokens to stdout as they are produced.
    output_ids = model.generate(
        **model_inputs,
        max_new_tokens=args.max_new_tokens,
        do_sample=True,
        streamer=TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True),
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        top_p=args.top_p,
        temperature=args.temperature,
        repetition_penalty=1.0,
    )

    # Drop the prompt tokens; decode only the newly generated continuation.
    prompt_length = len(model_inputs["input_ids"][0])
    answer = tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)
    print(f"response: {answer}")
    return
# Script entry point.
if __name__ == "__main__":
    main()