File size: 12,453 Bytes
27c46c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
"""
Interactive Chat Interface for Testing Fine-tuned Japanese Counseling Model
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import warnings
from datetime import datetime
import json

warnings.filterwarnings('ignore')

class CounselorChatInterface:
    """Interactive console chat around a locally fine-tuned counseling LM.

    Loads a causal language model and its tokenizer from ``model_path``,
    then offers an interactive REPL (``chat``), a scripted smoke test
    (``test_responses``) and JSON persistence of the dialogue
    (``save_conversation``).
    """

    def __init__(self, model_path: str = "./merged_counselor_model"):
        """
        Initialize the chat interface with the fine-tuned model.

        Args:
            model_path: Path to the fine-tuned model directory.
        """
        self.model_path = model_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        print("="*80)
        print("๐ŸŽŒ Japanese Counseling Model Chat Interface")
        print("="*80)
        print(f"๐Ÿ“ Device: {self.device}")

        if self.device.type == "cuda":
            print(f"   GPU: {torch.cuda.get_device_name(0)}")
            print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

        self.load_model()
        # Flat list of "Client: ..." / "Counselor: ..." strings, newest last.
        self.conversation_history = []

    def load_model(self):
        """Load the fine-tuned model and tokenizer.

        Tries the model's own tokenizer first; on failure falls back to the
        base "gpt2" tokenizer (assumes the fine-tune started from a GPT-2
        vocabulary -- TODO confirm against the training script).

        Raises:
            Exception: re-raised if both loading strategies fail.
        """
        print(f"\n๐Ÿค– Loading model from {self.model_path}...")

        try:
            # Load tokenizer shipped alongside the fine-tuned weights.
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                local_files_only=True
            )

            # generate() needs a pad token; reuse EOS when none was saved.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # fp16 + device_map only make sense on GPU; plain fp32 on CPU.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                device_map="auto" if self.device.type == "cuda" else None,
                local_files_only=True,
                trust_remote_code=True
            )

            self.model.eval()
            print("โœ… Model loaded successfully!")

        except Exception as e:
            # Deliberate best-effort boundary: report, then try a fallback.
            print(f"โŒ Error loading model: {e}")
            print("Trying alternative loading method...")

            try:
                self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
                if self.tokenizer.pad_token is None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token

                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_path,
                    torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                    local_files_only=True
                )
                self.model = self.model.to(self.device)
                self.model.eval()
                print("โœ… Model loaded with fallback tokenizer!")
            except Exception as e2:
                print(f"โŒ Failed to load model: {e2}")
                raise

    def _build_prompt(self, user_input: str, use_context: bool) -> str:
        """Assemble the Alpaca-style prompt, optionally with recent history.

        Args:
            user_input: The client's latest message.
            use_context: When True and history exists, include the last
                four history lines (two exchanges) as a Context section.

        Returns:
            The full prompt string ending with the "### Response:" marker.
        """
        instruction = (
            "### Instruction:\n"
            "ใ‚ใชใŸใฏๆ€ใ„ใ‚„ใ‚Šใฎใ‚ใ‚‹ๅฟƒ็†ใ‚ซใ‚ฆใƒณใ‚ปใƒฉใƒผใงใ™ใ€‚\n"
            "ใ‚ฏใƒฉใ‚คใ‚ขใƒณใƒˆใฎๆ„Ÿๆƒ…ใ‚’็†่งฃใ—ใ€ๅ…ฑๆ„Ÿ็š„ใงๆ”ฏๆด็š„ใชๅฟœ็ญ”ใ‚’ๆไพ›ใ—ใฆใใ ใ•ใ„ใ€‚\n"
        )
        if use_context and self.conversation_history:
            context = "\n".join(self.conversation_history[-4:])  # last 2 exchanges
            return (
                f"{instruction}\n### Context:\n{context}\n\n"
                f"### Input:\n{user_input}\n\n### Response:\n"
            )
        return f"{instruction}\n### Input:\n{user_input}\n\n### Response:\n"

    def generate_response(self, user_input: str,
                          temperature: float = 0.7,
                          max_length: int = 200,
                          use_context: bool = True) -> str:
        """
        Generate a counseling response.

        Args:
            user_input: User's message.
            temperature: Sampling temperature (0.1-1.0). Values <= 0 fall
                back to greedy decoding -- temperature 0 is invalid for
                multinomial sampling and would make generate() fail.
            max_length: Maximum number of newly generated tokens.
            use_context: Whether to use conversation history.

        Returns:
            Generated response text, or a fixed apology string on error.
        """
        prompt = self._build_prompt(user_input, use_context)

        # Tokenize; cap the prompt so long histories cannot overflow.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        )
        # No-op on CPU, moves tensors onto the GPU otherwise.
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # temperature <= 0 cannot be used with sampling; decode greedily.
        do_sample = temperature > 0

        try:
            # torch.autocast replaces the deprecated torch.cuda.amp.autocast
            # and handles both device types uniformly.
            with torch.no_grad(), torch.autocast(device_type=self.device.type):
                generate_kwargs = dict(
                    max_new_tokens=max_length,
                    do_sample=do_sample,
                    repetition_penalty=1.1,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
                if do_sample:
                    # Sampling knobs are only meaningful when sampling.
                    generate_kwargs.update(temperature=temperature,
                                           top_p=0.9, top_k=50)
                outputs = self.model.generate(**inputs, **generate_kwargs)

            full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Keep only the text after the final response marker; fall back
            # to stripping the prompt prefix if the marker was lost.
            if "### Response:" in full_response:
                return full_response.split("### Response:")[-1].strip()
            return full_response[len(prompt):].strip()

        except Exception as e:
            # Best-effort REPL: report and return a canned apology rather
            # than crashing the chat loop.
            print(f"Error generating response: {e}")
            return "็”ณใ—่จณใ”ใ–ใ„ใพใ›ใ‚“ใ€‚ๅฟœ็ญ”ใฎ็”Ÿๆˆไธญใซใ‚จใƒฉใƒผใŒ็™บ็”Ÿใ—ใพใ—ใŸใ€‚"

    def chat(self):
        """Run the interactive chat REPL until /quit, /exit or Ctrl-C."""
        print("\n" + "="*80)
        print("๐Ÿ’ฌ ใƒใƒฃใƒƒใƒˆใ‚’้–‹ๅง‹ใ—ใพใ™ (Chat session started)")
        print("="*80)
        print("Commands:")
        print("  /quit or /exit - ็ต‚ไบ† (Exit)")
        print("  /clear - ไผš่ฉฑๅฑฅๆญดใ‚’ใ‚ฏใƒชใ‚ข (Clear conversation history)")
        print("  /save - ไผš่ฉฑใ‚’ไฟๅญ˜ (Save conversation)")
        print("  /temp <value> - ๆธฉๅบฆใƒ‘ใƒฉใƒกใƒผใ‚ฟใ‚’่จญๅฎš (Set temperature, e.g., /temp 0.8)")
        print("  /context on/off - ใ‚ณใƒณใƒ†ใ‚ญใ‚นใƒˆไฝฟ็”จใฎๅˆ‡ใ‚Šๆ›ฟใˆ (Toggle context usage)")
        print("-"*80)

        temperature = 0.1
        use_context = True

        while True:
            try:
                user_input = input("\n๐Ÿ‘ค You: ").strip()

                # Skip empty lines instead of generating from an empty prompt.
                if not user_input:
                    continue

                if user_input.lower() in ['/quit', '/exit', '/q']:
                    print("\n๐Ÿ‘‹ ใ•ใ‚ˆใ†ใชใ‚‰๏ผ(Goodbye!)")
                    break

                elif user_input.lower() == '/clear':
                    self.conversation_history = []
                    print("โœ… ไผš่ฉฑๅฑฅๆญดใ‚’ใ‚ฏใƒชใ‚ขใ—ใพใ—ใŸ (Conversation history cleared)")
                    continue

                elif user_input.lower() == '/save':
                    self.save_conversation()
                    continue

                elif user_input.lower().startswith('/temp'):
                    try:
                        # Honor the user's value, clamped to the valid range.
                        temperature = max(0.1, min(1.0, float(user_input.split()[1])))
                        print(f"โœ… Temperature set to {temperature}")
                    except (IndexError, ValueError):
                        print("โŒ Invalid temperature. Use: /temp 0.7")
                    continue

                elif user_input.lower().startswith('/context'):
                    try:
                        setting = user_input.split()[1].lower()
                        use_context = setting == 'on'
                        print(f"โœ… Context {'enabled' if use_context else 'disabled'}")
                    except IndexError:
                        print("โŒ Use: /context on or /context off")
                    continue

                elif user_input.startswith('/'):
                    print("โŒ Unknown command")
                    continue

                print("\n๐Ÿค– Counselor: ", end="", flush=True)
                response = self.generate_response(
                    user_input,
                    temperature=temperature,
                    use_context=use_context
                )
                print(response)

                self.conversation_history.append(f"Client: {user_input}")
                self.conversation_history.append(f"Counselor: {response}")

            except KeyboardInterrupt:
                print("\n\n๐Ÿ‘‹ ใ•ใ‚ˆใ†ใชใ‚‰๏ผ(Goodbye!)")
                break
            except Exception as e:
                # Keep the REPL alive on unexpected per-turn failures.
                print(f"\nโŒ Error: {e}")
                continue

    def save_conversation(self):
        """Save the conversation history to a timestamped JSON file."""
        if not self.conversation_history:
            print("โŒ No conversation to save")
            return

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"conversation_{timestamp}.json"

        conversation_data = {
            "timestamp": timestamp,
            "model_path": self.model_path,
            "conversation": self.conversation_history
        }

        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(conversation_data, f, ensure_ascii=False, indent=2)

        # Report the actual file name (was a hard-coded "(unknown)").
        print(f"โœ… Conversation saved to {filename}")

    def test_responses(self):
        """Run a fixed battery of inputs at two temperatures as a smoke test."""
        print("\n" + "="*80)
        print("๐Ÿงช Testing Model Responses")
        print("="*80)

        test_inputs = [
            "ใ“ใ‚“ใซใกใฏใ€‚ๆœ€่ฟ‘ใ‚นใƒˆใƒฌใ‚นใ‚’ๆ„Ÿใ˜ใฆใ„ใพใ™ใ€‚",
            "ไป•ไบ‹ใŒใ†ใพใใ„ใ‹ใชใใฆๆ‚ฉใ‚“ใงใ„ใพใ™ใ€‚",
            "ไบบ้–“้–ขไฟ‚ใงๅ›ฐใฃใฆใ„ใพใ™ใ€‚ใฉใ†ใ™ใ‚Œใฐใ„ใ„ใงใ—ใ‚‡ใ†ใ‹ใ€‚",
            "ๅฐ†ๆฅใŒไธๅฎ‰ใง็œ ใ‚Œใพใ›ใ‚“ใ€‚",
            "่‡ชๅˆ†ใซ่‡ชไฟกใŒๆŒใฆใพใ›ใ‚“ใ€‚",
            "ๅฎถๆ—ใจใฎ้–ขไฟ‚ใงๆ‚ฉใ‚“ใงใ„ใพใ™ใ€‚",
            "ๆฏŽๆ—ฅใŒ่พ›ใ„ใงใ™ใ€‚",
            "่ชฐใซใ‚‚็›ธ่ซ‡ใงใใพใ›ใ‚“ใ€‚"
        ]

        print("\nTesting with different temperature settings:\n")

        # temp == 0 exercises the greedy-decoding fallback in
        # generate_response; 0.1 exercises near-deterministic sampling.
        for temp in [0, 0.1]:
            print(f"\n๐ŸŒก๏ธ Temperature: {temp}")
            print("-"*60)

            for i, test_input in enumerate(test_inputs[:3], 1):
                print(f"\n{i}. Input: {test_input}")
                response = self.generate_response(test_input, temperature=temp, use_context=False)
                print(f"   Response: {response[:200]}...")
                print()

        print("="*80)


def main():
    """Entry point: parse CLI options, locate the model, launch chat/tests."""
    import argparse

    arg_parser = argparse.ArgumentParser(description='Chat with fine-tuned counseling model')
    arg_parser.add_argument('--model_path', type=str, default='./merged_counselor_mode_2b',
                            help='Path to the fine-tuned model')
    arg_parser.add_argument('--test_only', action='store_true',
                            help='Only run test responses without chat')
    args = arg_parser.parse_args()

    # Guard clause: bail out with a hint if the model directory is missing.
    if not os.path.exists(args.model_path):
        print(f"โŒ Model not found at {args.model_path}")
        print("\nAvailable models:")
        candidates = (entry for entry in os.listdir('.')
                      if 'model' in entry.lower() and os.path.isdir(entry))
        for entry in candidates:
            print(f"  - {entry}")
        return

    try:
        interface = CounselorChatInterface(model_path=args.model_path)
        # Either run the scripted smoke test or drop into the REPL.
        if args.test_only:
            interface.test_responses()
        else:
            interface.chat()

    except Exception as e:
        # Top-level boundary: report with a full traceback, never re-raise.
        print(f"โŒ Error: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()