File size: 3,859 Bytes
2c31fd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python3

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Tokenizer and causal LM shared by every function below. The original note
# says this setup is the same as server.py (not verifiable from this file).
model_name = "models/Llama-3.2-1B-Instruct"
tok = AutoTokenizer.from_pretrained(model_name)

# Weights are loaded in bfloat16 and placed on the CUDA device.
lm = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="cuda"
)
# Switch to evaluation mode; eval() returns the module itself.
lm = lm.eval()

def chat_current(system_prompt: str, user_prompt: str) -> str:
    """Generate a reply using the baseline call pattern.

    The prompt is rendered with the tokenizer's chat template and handed to
    ``lm.generate`` as a bare tensor of token ids: no ``attention_mask`` and
    no ``pad_token_id`` are supplied. According to the original author's
    note this mirrors server.py and is expected to emit warnings.

    Args:
        system_prompt: Content of the ``system`` message.
        user_prompt: Content of the ``user`` message.

    Returns:
        The decoded completion (prompt tokens removed, special tokens
        skipped), stripped of surrounding whitespace.
    """
    print("🔴 Running CURRENT implementation (with warnings)...")

    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Template -> token ids only (no dict, hence no attention mask).
    prompt_ids = tok.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    prompt_ids = prompt_ids.to(lm.device)
    prompt_len = prompt_ids.shape[-1]

    sampling_kwargs = {
        "max_new_tokens": 2048,
        "do_sample": True,
        "temperature": 0.2,
        "repetition_penalty": 1.1,
        "top_k": 100,
        "top_p": 0.95,
    }

    with torch.inference_mode():
        # Deliberately positional ids only: no attention_mask, no pad_token_id.
        generated = lm.generate(prompt_ids, **sampling_kwargs)

    # generate() returns prompt + completion; keep only the new tokens.
    completion_ids = generated[0][prompt_len:]
    text = tok.decode(
        completion_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return text.strip()


def chat_fixed(system_prompt: str, user_prompt: str) -> str:
    """Generate a reply while passing an attention mask and pad token id.

    Same prompt construction and sampling settings as the baseline, but the
    chat template is asked for a dict so that ``attention_mask`` is available,
    and ``pad_token_id`` is set explicitly to the tokenizer's EOS id.

    Args:
        system_prompt: Content of the ``system`` message.
        user_prompt: Content of the ``user`` message.

    Returns:
        The decoded completion (prompt tokens removed, special tokens
        skipped), stripped of surrounding whitespace.
    """
    print("🟢 Running FIXED implementation (no warnings)...")

    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # return_dict=True yields both "input_ids" and "attention_mask".
    encoded = tok.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )

    device = lm.device
    prompt_ids = encoded["input_ids"].to(device)
    prompt_mask = encoded["attention_mask"].to(device)
    prompt_len = prompt_ids.shape[-1]

    sampling_kwargs = {
        "max_new_tokens": 2048,
        "do_sample": True,
        "temperature": 0.2,
        "repetition_penalty": 1.1,
        "top_k": 100,
        "top_p": 0.95,
    }

    with torch.inference_mode():
        generated = lm.generate(
            input_ids=prompt_ids,
            attention_mask=prompt_mask,
            pad_token_id=tok.eos_token_id,  # explicit pad id
            **sampling_kwargs,
        )

    # generate() returns prompt + completion; keep only the new tokens.
    completion_ids = generated[0][prompt_len:]
    text = tok.decode(
        completion_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return text.strip()


def compare_generations(seed: int = 0) -> None:
    """Run both chat implementations on one prompt and print a comparison.

    Both implementations sample (``do_sample=True``), so without a fixed RNG
    state the two outputs differ through sampling noise alone and the
    "Outputs are identical" line says nothing about the implementations.
    The torch RNG is therefore re-seeded with the same value immediately
    before each run, so any remaining difference is attributable to the
    generation arguments rather than to the random draw.

    Args:
        seed: Value passed to ``torch.manual_seed`` before each generation.
            Defaults to 0; callers that pass nothing get a reproducible run.

    Returns:
        None. Results are printed to stdout.
    """
    system_prompt = "You are a helpful assistant who tries to help answer the user's question."
    user_prompt = "Create a report on anxiety in work. How do I manage time and stress effectively?"
    divider = "=" * 60

    print(divider)
    print("COMPARING GENERATION METHODS")
    print(divider)
    print(f"System: {system_prompt}")
    print(f"User: {user_prompt}")
    print(divider)

    # Baseline implementation.
    print("\n" + divider)
    torch.manual_seed(seed)  # identical RNG state for both runs
    current_output = chat_current(system_prompt, user_prompt)
    print(f"CURRENT OUTPUT:\n{current_output}")

    # Fixed implementation, started from the same RNG state as the baseline.
    print("\n" + divider)
    torch.manual_seed(seed)
    fixed_output = chat_fixed(system_prompt, user_prompt)
    print(f"FIXED OUTPUT:\n{fixed_output}")

    print("\n" + divider)
    print("COMPARISON:")
    print(f"Outputs are identical: {current_output == fixed_output}")
    print(f"Current length: {len(current_output)} chars")
    print(f"Fixed length: {len(fixed_output)} chars")


if __name__ == "__main__":
    # Script entry point: make sure the tokenizer has a pad token (falling
    # back to the EOS token when none is configured), then run the comparison.
    pad_token_missing = tok.pad_token is None
    if pad_token_missing:
        tok.pad_token = tok.eos_token

    compare_generations()