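"""Quick test script for a LoRA-fine-tuned DialoGPT-small model.

Loads the base model plus the adapter saved in ./dialogpt-small-lora and runs a
couple of instruction-style prompts that ask for a user JSON object.
"""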
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load the base model and tokenizer
model_name = "microsoft/DialoGPT-small"
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Load the LoRA adapter on top of the base model
model = PeftModel.from_pretrained(base_model, "./dialogpt-small-lora")
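model.eval()  # inference mode: disables dropout

# Optionally merge the LoRA weights into the base model for slightly faster
# generation (standard peft API; only if the adapter no longer needs to stay separate):
# model = model.merge_and_unload()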

# Run a single instruction/input pair through the fine-tuned model
def test_model(instruction, input_text):
    prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
    
    # Tokenize the prompt
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    
    # Generate a response
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=50,  # cap the number of newly generated tokens
            num_return_sequences=1,
            temperature=0.3,  # lower temperature for more deterministic output
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1  # discourage repeated tokens
        )
    
    # Decode the output
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Keep only the generated portion (strip the prompt prefix)
    generated_text = response[len(prompt):].strip()
    
    return generated_text
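
# Optional helper (not part of the original script): the test prompts ask for a
# user JSON object, so this checks that the generated text is well-formed JSON.
import json

def try_parse_json(text):
    """Return the parsed object if `text` is valid JSON, otherwise None."""
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return None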

# Test examples
if __name__ == "__main__":
    print("Testing the fine-tuned model...")
    print("="*50)
    
    # Test 1: generate a user JSON object
    instruction = "Generate a user JSON object from the following information."
    input_text = "The user ID is 999, the username is test_user, and the email is test@example.com"
    
    result = test_model(instruction, input_text)
    print(f"指令: {instruction}")
    print(f"输入: {input_text}")
    print(f"输出: {result}")
    print("="*50)
    
    # Test 2: another example
    instruction = "Generate a user JSON object from the following information."
    input_text = "The user ID is 888, the username is admin, and the email is admin@company.com"
    
    result = test_model(instruction, input_text)
    print(f"指令: {instruction}")
    print(f"输入: {input_text}")
    print(f"输出: {result}")
    print("="*50)