RianLi committed on
Commit
1286204
·
verified ·
1 Parent(s): fe4323b

Delete test_model.py

Browse files
Files changed (1) hide show
  1. test_model.py +0 -70
test_model.py DELETED
@@ -1,70 +0,0 @@
1
- import torch
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
- from peft import PeftModel
4
-
5
# Load the base causal-LM checkpoint and its tokenizer.
model_name = "microsoft/DialoGPT-small"
base_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float32, low_cpu_mem_usage=True
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Use the EOS token for padding as well.
tokenizer.pad_token = tokenizer.eos_token

# Attach the LoRA adapter stored in ./dialogpt-small-lora to the base model.
model = PeftModel.from_pretrained(base_model, "./dialogpt-small-lora")
17
-
18
- # 测试函数
19
def test_model(instruction: str, input_text: str) -> str:
    """Generate the adapter-wrapped model's response for one prompt.

    The prompt is laid out as ``### Instruction:``, ``### Input:`` and
    ``### Response:`` sections, tokenized with the module-level
    ``tokenizer`` and passed to the module-level ``model`` for sampling.

    Args:
        instruction: Text placed under the ``### Instruction:`` header.
        input_text: Text placed under the ``### Input:`` header.

    Returns:
        The newly generated text only (prompt excluded), stripped of
        leading and trailing whitespace.
    """
    prompt = (
        f"### Instruction:\n{instruction}\n\n"
        f"### Input:\n{input_text}\n\n"
        "### Response:\n"
    )

    # Encode the prompt as a batch of one sequence.
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    prompt_token_count = inputs["input_ids"].shape[1]

    # Sampling needs no gradients.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=50,  # cap on newly generated tokens
            num_return_sequences=1,
            temperature=0.3,  # low temperature for more deterministic output
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,  # discourage repetition
        )

    # The returned sequence starts with the prompt tokens. Drop them by token
    # count instead of slicing the decoded string by len(prompt): the decoded
    # prompt is not guaranteed to match the original text character for
    # character (truncation, decoding cleanup), which would cut the response
    # at the wrong offset.
    generated_ids = outputs[0][prompt_token_count:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
46
-
47
- # 测试示例
48
if __name__ == "__main__":
    separator = "=" * 50
    print("测试微调后的模型...")
    print(separator)

    # Smoke-test cases as (instruction, input_text) pairs; both ask for a
    # user JSON object and differ only in the user details supplied.
    cases = [
        (
            "根据以下信息,生成一个用户JSON对象。",
            "用户ID是999,用户名是test_user,邮箱是test@example.com",
        ),
        (
            "根据以下信息,生成一个用户JSON对象。",
            "用户ID是888,用户名是admin,邮箱是admin@company.com",
        ),
    ]

    for instruction, input_text in cases:
        result = test_model(instruction, input_text)
        print(f"指令: {instruction}")
        print(f"输入: {input_text}")
        print(f"输出: {result}")
        print(separator)