from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import List, Tuple
import os
import re

# Universal chat wrapper around a Hugging Face causal LM, with a ChatML fallback
class UniversalChatModel:
    def __init__(self, model_name: str):
        self.model_name = model_name
        print(f"Loading tokenizer for {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=os.getenv("HF_TOKEN")
        )
        
        # Set padding token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token or "<|endoftext|>"
            
        print(f"Loading model {model_name}...")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            token=os.getenv("HF_TOKEN")
        )
        
        print("Model loaded successfully!")
        
    def format_prompt_fallback(self, messages: List[dict]) -> str:
        """Universal ChatML format for models without chat templates"""
        chatml = ""
        for message in messages:
            role = message["role"]
            content = message["content"]
            chatml += f"<|im_start|>{role}\n{content}<|im_end|>\n"
        chatml += "<|im_start|>assistant\n"
        return chatml
        
    def build_messages(self, history: List[Tuple[str, str]], current_message: str, system_prompt: str = None) -> List[dict]:
        """Build universal message format"""
        messages = []
        
        # Add system prompt
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        
        # Add history
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})
        
        # Add current message
        messages.append({"role": "user", "content": current_message})
        
        return messages
        
    def format_prompt(self, messages: List[dict]) -> str:
        """Format prompt using model's chat template or fallback"""
        # Try model's built-in chat template
        if hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template:
            try:
                prompt = self.tokenizer.apply_chat_template(
                    messages, 
                    tokenize=False, 
                    add_generation_prompt=True
                )
                return prompt
            except Exception as e:
                print(f"Warning: Failed to use chat template: {e}")
        
        # Fallback to ChatML
        return self.format_prompt_fallback(messages)
        
    def extract_response(self, prompt: str, generated_text: str) -> str:
        """Enhanced universal response extraction with comprehensive token cleanup"""
        
        # Define token patterns to clean
        token_patterns = [
            # ChatML tokens
            "<|im_start|>", "<|im_end|>", 
            # Common end tokens
            "</s>", "<eos>", 
            # Llama-style instruction tokens
            "[/inst]", "[inst]", "<</sys>>", "<</inst>>",
            # Additional patterns
            "[/assistant]", "[assistant]", "[/user]", "[user]"
        ]
        
        def clean_response(response: str) -> str:
            """Clean all known token patterns from the response (case-insensitively)"""
            for pattern in token_patterns:
                response = re.sub(re.escape(pattern), "", response, flags=re.IGNORECASE)
            # Clean up extra whitespace and newlines
            response = response.strip()
            # Remove leading/trailing punctuation that might be left over from tokens
            response = response.strip('.,!?;:\n\r\t ')
            return response
        
        def extract_chatml_response(text: str) -> str | None:
            """Extract response using ChatML markers"""
            if "<|im_start|>assistant\n" in text:
                parts = text.split("<|im_start|>assistant\n")
                if len(parts) > 1:
                    response = parts[-1]
                    if "<|im_end|>" in response:
                        response = response.split("<|im_end|>")[0]
                    return clean_response(response)
            return None
        
        def extract_inst_response(text: str) -> str | None:
            """Extract response using the Llama-style [INST] ... [/INST] pattern"""
            # Prefer the closing tag: the model's reply follows the last [/INST].
            # Slice the original text so the response keeps its casing.
            lowered = text.lower()
            for pattern in ("[/inst]", "[inst]"):
                idx = lowered.rfind(pattern)
                if idx != -1:
                    return clean_response(text[idx + len(pattern):])
            return None
        
        def extract_after_prompt(text: str, prompt: str) -> str | None:
            """Extract response that comes directly after prompt"""
            if text.startswith(prompt):
                response = text[len(prompt):]
                return clean_response(response)
            return None
        
        def extract_last_assistant_message(text: str) -> str:
            """Fallback: find the last assistant-like message"""
            # Look for various assistant indicators; slice the original text
            # (not the lowercased copy) so the response keeps its casing.
            assistant_indicators = ["assistant", "[/inst]", "[inst]"]
            text_lower = text.lower()

            best_response = ""
            for indicator in assistant_indicators:
                idx = text_lower.rfind(indicator)
                if idx != -1:
                    candidate = text[idx + len(indicator):]
                    if len(candidate) > len(best_response):
                        best_response = candidate

            if best_response:
                return clean_response(best_response)

            # Final fallback: just clean the whole thing
            return clean_response(text)
        
        print(f"\n\n----- EXTRACTION DEBUG -----\n")
        print(f"Generated text length: {len(generated_text)}")
        print(f"Prompt length: {len(prompt)}")
        print(f"Generated starts with prompt: {generated_text.startswith(prompt)}")
        
        # Try extraction methods in order of reliability
        response = None
        
        # Method 1: ChatML extraction
        response = extract_chatml_response(generated_text)
        if response:
            print("Used ChatML extraction")
        else:
            # Method 2: [INST]-style pattern extraction
            response = extract_inst_response(generated_text)
            if response:
                print("Used inst pattern extraction")
            else:
                # Method 3: Extract after prompt
                response = extract_after_prompt(generated_text, prompt)
                if response:
                    print("Used extract-after-prompt method")
                else:
                    # Method 4: Fallback to last assistant message
                    response = extract_last_assistant_message(generated_text)
                    print("Used fallback extraction")
        
        print(f"Extracted response: '{response}'")
        print(f"Response length: {len(response)}")
        print(f"----- END EXTRACTION DEBUG -----\n\n")
        
        return response or ""
        
    def generate(self, message: str, history: List[Tuple[str, str]] | None = None, system_prompt: str | None = None) -> str:
        """Generate response using universal chat template system"""
        if history is None:
            history = []
            
        # Build messages
        messages = self.build_messages(history, message, system_prompt)
        
        # Format prompt
        prompt = self.format_prompt(messages)
        
        print(f"\n\n----- PROMPT -----\n{prompt}\n-----------------\n\n")
        
        # Tokenize
        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
        # Move to model device
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        
        # Generate
        generation_config = {
            "max_new_tokens": 150,
            "do_sample": True,
            "temperature": 0.7,
            "pad_token_id": self.tokenizer.eos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                **generation_config
            )
        
        # Decode
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        print(f"\n\n----- GENERATED -----\n{generated_text}\n-------------------\n\n")
        
        # Extract response
        response = self.extract_response(prompt, generated_text)
        
        return response

# Model selection: uncomment the tiny random model below for quick smoke tests;
# Llama-2-7b-chat is gated and needs an HF_TOKEN with approved access.
# MODEL_NAME = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
SYSTEM_PROMPT = "You are a helpful assistant."

# Create global model instance
print("Creating model instance...")
chat_model = UniversalChatModel(MODEL_NAME)

def generate(message: str, history: List[Tuple[str, str]]) -> str:
    """Generate response using universal chat model"""
    return chat_model.generate(message, history, SYSTEM_PROMPT)

if __name__ == "__main__":
    # Quick test
    print("Testing generation...")
    reply = generate("What is 2+2?", [])
    print(f"Final response: {reply}")