"""
Interactive Chat Interface for Testing Fine-tuned Japanese Counseling Model
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import warnings
from datetime import datetime
import json

# Silence library warnings (e.g. generation-config notices) to keep console output clean.
warnings.filterwarnings('ignore')


class CounselorChatInterface:
    """Console chat front-end for a locally stored, fine-tuned counseling model.

    The model and tokenizer are loaded on construction. ``chat()`` runs an
    interactive loop; ``test_responses()`` runs a fixed set of prompts.
    """

    # Japanese instruction placed at the top of every prompt. In English:
    # "You are a compassionate psychological counselor. Understand the client's
    # feelings and provide empathetic, supportive responses."
    # NOTE(review): the template must match the one used during fine-tuning —
    # verify section headers / blank lines against the training script.
    INSTRUCTION = (
        "あなたは思いやりのある心理カウンセラーです。\n"
        "クライアントの感情を理解し、共感的で支援的な応答を提供してください。"
    )

    # Token budget for the prompt; the oldest history is dropped to stay within it.
    MAX_PROMPT_TOKENS = 512

    # Number of history lines used as context (2 lines per exchange => last 2 exchanges).
    CONTEXT_LINES = 4

    # Range the /temp command clamps to; 0 selects greedy decoding.
    MIN_TEMPERATURE = 0.0
    MAX_TEMPERATURE = 2.0

    def __init__(self, model_path: str = "./merged_counselor_model"):
        """
        Initialize the chat interface with the fine-tuned model

        Args:
            model_path: Path to the fine-tuned model directory (loaded with
                local_files_only=True, so it must exist locally).
        """
        self.model_path = model_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Alternating "Client: ..." / "Counselor: ..." lines.
        self.conversation_history = []

        print("=" * 80)
        print("🎌 Japanese Counseling Model Chat Interface")
        print("=" * 80)
        print(f"📍 Device: {self.device}")
        if self.device.type == "cuda":
            print(f"   GPU: {torch.cuda.get_device_name(0)}")
            print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

        self.load_model()

    def load_model(self):
        """Load the fine-tuned model and tokenizer into ``self.model`` / ``self.tokenizer``.

        Falls back to the ``gpt2`` tokenizer if loading from ``model_path``
        fails; re-raises if the fallback also fails.
        """
        print(f"\n🤖 Loading model from {self.model_path}...")
        dtype = torch.float16 if self.device.type == "cuda" else torch.float32

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                local_files_only=True
            )
            # generate() needs a pad token; reuse EOS when the tokenizer has none.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                torch_dtype=dtype,
                device_map="auto" if self.device.type == "cuda" else None,
                local_files_only=True,
                trust_remote_code=True
            )
            self.model.eval()
            print("✅ Model loaded successfully!")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            print("Trying alternative loading method...")

            # NOTE(review): the gpt2 tokenizer is unlikely to match a Japanese
            # model's vocabulary, and it may be fetched from the network —
            # confirm this fallback is still wanted.
            try:
                self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
                if self.tokenizer.pad_token is None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token

                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_path,
                    torch_dtype=dtype,
                    local_files_only=True
                )
                self.model = self.model.to(self.device)
                self.model.eval()
                print("✅ Model loaded with fallback tokenizer!")
            except Exception as e2:
                print(f"❌ Failed to load model: {e2}")
                raise

    def _build_prompt(self, user_input: str, context_lines=None) -> str:
        """Return the instruction prompt, with a ### Context section when context_lines is non-empty."""
        sections = [f"### Instruction:\n{self.INSTRUCTION}"]
        if context_lines:
            sections.append("### Context:\n" + "\n".join(context_lines))
        sections.append(f"### Input:\n{user_input}")
        sections.append("### Response:\n")
        return "\n\n".join(sections)

    def _fit_prompt(self, user_input: str, history, max_tokens: int) -> str:
        """Build a prompt, dropping the oldest exchanges until it fits in max_tokens.

        Dropping history (rather than truncating the token sequence) keeps the
        user's input and the trailing "### Response:" marker intact.
        """
        context = list(history)
        while True:
            prompt = self._build_prompt(user_input, context)
            n_tokens = len(self.tokenizer(prompt)["input_ids"])
            if n_tokens <= max_tokens or not context:
                return prompt
            del context[:2]  # one exchange = Client line + Counselor line

    def generate_response(self, user_input: str, temperature: float = 0,
                          max_length: int = 200, use_context: bool = True) -> str:
        """
        Generate a counseling response

        Args:
            user_input: User's message
            temperature: Sampling temperature. Values <= 0 use greedy
                decoding; positive values (typically 0.1-1.0) enable sampling.
            max_length: Maximum number of newly generated tokens
            use_context: Whether to include the last two exchanges as context

        Returns:
            The generated response text, or a Japanese apology message
            ("Sorry, an error occurred while generating the response.")
            if generation fails.
        """
        history = self.conversation_history[-self.CONTEXT_LINES:] if use_context else []
        prompt = self._fit_prompt(user_input, history, self.MAX_PROMPT_TOKENS)

        # truncation stays as a last resort for a single over-long user input.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.MAX_PROMPT_TOKENS
        )
        # model.device is the first parameter's device (also valid with device_map="auto").
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        gen_kwargs = {
            "max_new_tokens": max_length,
            "repetition_penalty": 1.1,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=float(temperature), top_p=0.9, top_k=50)
        else:
            # transformers rejects temperature <= 0 when sampling; use greedy decoding instead.
            gen_kwargs["do_sample"] = False

        try:
            # No autocast: the model's dtype is already chosen at load time, and
            # CPU autocast would silently downcast the fp32 model to bfloat16.
            with torch.no_grad():
                outputs = self.model.generate(**inputs, **gen_kwargs)

            # A causal LM echoes the prompt; decode only the newly generated tokens.
            prompt_len = inputs["input_ids"].shape[1]
            response = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
            # Stop at any further template section the model starts on its own.
            response = response.split("###", 1)[0]
            return response.strip()

        except Exception as e:
            print(f"Error generating response: {e}")
            return "申し訳ございません。応答の生成中にエラーが発生しました。"

    def chat(self):
        """Start an interactive chat session on stdin/stdout until /quit or Ctrl-C."""
        print("\n" + "=" * 80)
        print("💬 チャットを開始します (Chat session started)")
        print("=" * 80)
        print("Commands:")
        print("  /quit or /exit - 終了 (Exit)")
        print("  /clear - 会話履歴をクリア (Clear conversation history)")
        print("  /save - 会話を保存 (Save conversation)")
        print("  /temp - 温度パラメータを設定 (Set temperature, e.g., /temp 0.8)")
        print("  /context on/off - コンテキスト使用の切り替え (Toggle context usage)")
        print("-" * 80)

        temperature = 0.1
        use_context = True

        while True:
            try:
                user_input = input("\n👤 You: ").strip()
                if not user_input:
                    continue
                command = user_input.lower()

                if command in ['/quit', '/exit', '/q']:
                    print("\n👋 さようなら！(Goodbye!)")
                    break

                elif command == '/clear':
                    self.conversation_history = []
                    print("✅ 会話履歴をクリアしました (Conversation history cleared)")
                    continue

                elif command == '/save':
                    self.save_conversation()
                    continue

                elif command.startswith('/temp'):
                    try:
                        requested = float(user_input.split()[1])
                    except (IndexError, ValueError):
                        print("❌ Invalid temperature. Use: /temp 0.7")
                        continue
                    # Clamp instead of discarding the user's value; 0 => greedy decoding.
                    temperature = max(self.MIN_TEMPERATURE, min(self.MAX_TEMPERATURE, requested))
                    print(f"✅ Temperature set to {temperature}")
                    continue

                elif command.startswith('/context'):
                    parts = command.split()
                    if len(parts) < 2 or parts[1] not in ('on', 'off'):
                        print("❌ Use: /context on or /context off")
                        continue
                    use_context = parts[1] == 'on'
                    print(f"✅ Context {'enabled' if use_context else 'disabled'}")
                    continue

                elif user_input.startswith('/'):
                    print("❌ Unknown command")
                    continue

                print("\n🤖 Counselor: ", end="", flush=True)
                response = self.generate_response(
                    user_input,
                    temperature=temperature,
                    use_context=use_context
                )
                print(response)

                self.conversation_history.append(f"Client: {user_input}")
                self.conversation_history.append(f"Counselor: {response}")

            except KeyboardInterrupt:
                print("\n\n👋 さようなら！(Goodbye!)")
                break
            except Exception as e:
                print(f"\n❌ Error: {e}")
                continue

    def save_conversation(self):
        """Write the conversation history to ./conversation_<timestamp>.json (UTF-8)."""
        if not self.conversation_history:
            print("❌ No conversation to save")
            return

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"conversation_{timestamp}.json"

        conversation_data = {
            "timestamp": timestamp,
            "model_path": self.model_path,
            "conversation": self.conversation_history
        }

        # ensure_ascii=False keeps Japanese text readable in the file.
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(conversation_data, f, ensure_ascii=False, indent=2)

        print(f"✅ Conversation saved to {filename}")

    def test_responses(self):
        """Print responses for the first three predefined inputs at each test temperature."""
        print("\n" + "=" * 80)
        print("🧪 Testing Model Responses")
        print("=" * 80)

        # Typical client openings (stress, work, relationships, anxiety, self-esteem,
        # family, daily hardship, isolation).
        test_inputs = [
            "こんにちは。最近ストレスを感じています。",
            "仕事がうまくいかなくて悩んでいます。",
            "人間関係で困っています。どうすればいいでしょうか。",
            "将来が不安で眠れません。",
            "自分に自信が持てません。",
            "家族との関係で悩んでいます。",
            "毎日が辛いです。",
            "誰にも相談できません。"
        ]

        print("\nTesting with different temperature settings:\n")

        # 0 runs greedy decoding; 0.1 samples with a low temperature.
        for temp in [0, 0.1]:
            print(f"\n🌡️ Temperature: {temp}")
            print("-" * 60)

            for i, test_input in enumerate(test_inputs[:3], 1):
                print(f"\n{i}. Input: {test_input}")
                response = self.generate_response(test_input, temperature=temp, use_context=False)
                print(f"   Response: {response[:200]}...")

            print()

        print("=" * 80)


def main():
    """Parse CLI arguments, then run either the test prompts or the interactive chat."""
    import argparse

    parser = argparse.ArgumentParser(description='Chat with fine-tuned counseling model')
    # NOTE(review): this default differs from the class default
    # './merged_counselor_model' (possibly a typo for 'model_2b') — confirm.
    parser.add_argument('--model_path', type=str, default='./merged_counselor_mode_2b',
                        help='Path to the fine-tuned model')
    parser.add_argument('--test_only', action='store_true',
                        help='Only run test responses without chat')

    args = parser.parse_args()

    if not os.path.exists(args.model_path):
        print(f"❌ Model not found at {args.model_path}")
        print("\nAvailable models:")
        for item in os.listdir('.'):
            if 'model' in item.lower() and os.path.isdir(item):
                print(f"  - {item}")
        return

    try:
        chat = CounselorChatInterface(model_path=args.model_path)

        if args.test_only:
            chat.test_responses()
        else:
            chat.chat()

    except Exception as e:
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()