| """ |
| Interactive Chat Interface for Testing Fine-tuned Japanese Counseling Model |
| """ |
|
|
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| import os |
| import warnings |
| from datetime import datetime |
| import json |
|
|
# Silence noisy third-party warnings (torch/transformers deprecation messages)
# so they don't clutter the interactive chat output.
warnings.filterwarnings('ignore')
|
|
class CounselorChatInterface:
    """Interactive terminal chat interface for a fine-tuned Japanese counseling model.

    Loads a local fine-tuned causal LM (with a GPT-2 tokenizer fallback), then
    offers an interactive REPL (`chat`), a scripted smoke test (`test_responses`),
    and conversation persistence (`save_conversation`).
    """

    def __init__(self, model_path: str = "./merged_counselor_model"):
        """
        Initialize the chat interface with the fine-tuned model.

        Args:
            model_path: Path to the fine-tuned model directory.
        """
        self.model_path = model_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        print("="*80)
        print("๐ Japanese Counseling Model Chat Interface")
        print("="*80)
        print(f"๐ Device: {self.device}")

        if self.device.type == "cuda":
            print(f" GPU: {torch.cuda.get_device_name(0)}")
            print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

        self.load_model()
        # Alternating "Client: ..." / "Counselor: ..." lines, newest last.
        self.conversation_history = []

    def load_model(self):
        """Load the fine-tuned model and tokenizer.

        Falls back to the stock GPT-2 tokenizer if the model directory does not
        contain usable tokenizer files (NOTE: the fallback downloads from the
        Hugging Face Hub, so it needs network access).
        """
        print(f"\n๐ค Loading model from {self.model_path}...")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                local_files_only=True
            )
            # Causal LMs often ship without a pad token; reuse EOS for padding.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                # fp16 halves GPU memory; CPU inference stays in fp32.
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                device_map="auto" if self.device.type == "cuda" else None,
                local_files_only=True,
                trust_remote_code=True
            )
            self.model.eval()
            print("โ Model loaded successfully!")

        except Exception as e:
            print(f"โ Error loading model: {e}")
            print("Trying alternative loading method...")

            try:
                # Borrow GPT-2's tokenizer when the saved directory lacks one.
                self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
                if self.tokenizer.pad_token is None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token

                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_path,
                    torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                    local_files_only=True
                )
                self.model = self.model.to(self.device)
                self.model.eval()
                print("โ Model loaded with fallback tokenizer!")
            except Exception as e2:
                print(f"โ Failed to load model: {e2}")
                raise

    def _build_prompt(self, user_input: str, use_context: bool) -> str:
        """Assemble the Alpaca-style instruction prompt for one turn.

        Args:
            user_input: The client's current message.
            use_context: When True and history exists, include the last two
                exchanges (4 lines) as a ### Context: section.

        Returns:
            The full prompt string ending with an open "### Response:" section.
        """
        instruction = (
            "### Instruction:\n"
            "ใใชใใฏๆใใใใฎใใๅฟ็ใซใฆใณใปใฉใผใงใใ\n"
            "ใฏใฉใคใขใณใใฎๆๆใ็่งฃใใๅฑๆ็ใงๆฏๆด็ใชๅฟ็ญใๆไพใใฆใใ ใใใ\n"
        )
        if use_context and len(self.conversation_history) > 0:
            # Only the last 4 lines (two exchanges) to stay within the
            # 512-token input truncation limit applied at tokenization.
            context = "\n".join(self.conversation_history[-4:])
            return (
                f"{instruction}\n### Context:\n{context}\n\n"
                f"### Input:\n{user_input}\n\n### Response:\n"
            )
        return f"{instruction}\n### Input:\n{user_input}\n\n### Response:\n"

    def generate_response(self, user_input: str,
                          temperature: float = 0,
                          max_length: int = 200,
                          use_context: bool = True) -> str:
        """
        Generate a counseling response.

        Args:
            user_input: User's message.
            temperature: Sampling temperature. Values <= 0 select greedy
                decoding — transformers rejects do_sample=True with a
                non-positive temperature, so 0 must not be passed as a
                sampling temperature.
            max_length: Maximum number of newly generated tokens.
            use_context: Whether to prepend recent conversation history.

        Returns:
            Generated response text, or a Japanese apology message on error.
        """
        prompt = self._build_prompt(user_input, use_context)

        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        )

        if self.device.type == "cuda":
            inputs = {k: v.cuda() for k, v in inputs.items()}

        # BUGFIX: the previous code always sampled (do_sample=True) even with
        # temperature=0, which transformers rejects. Treat <= 0 as greedy.
        do_sample = temperature > 0
        gen_kwargs = dict(
            max_new_tokens=max_length,
            do_sample=do_sample,
            repetition_penalty=1.1,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        if do_sample:
            gen_kwargs.update(temperature=temperature, top_p=0.9, top_k=50)

        try:
            with torch.no_grad():
                # torch.cuda.amp.autocast() is deprecated; the unified
                # torch.autocast covers both CUDA and CPU.
                with torch.autocast(device_type=self.device.type):
                    outputs = self.model.generate(**inputs, **gen_kwargs)

            full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Strip the echoed prompt; keep only the model's answer.
            if "### Response:" in full_response:
                response = full_response.split("### Response:")[-1].strip()
            else:
                response = full_response[len(prompt):].strip()

            return response

        except Exception as e:
            print(f"Error generating response: {e}")
            return "็ณใ่จณใใใใพใใใๅฟ็ญใฎ็ๆไธญใซใจใฉใผใ็บ็ใใพใใใ"

    def chat(self):
        """Run the interactive chat REPL until /quit, /exit, or Ctrl-C."""
        print("\n" + "="*80)
        print("๐ฌ ใใฃใใใ้ๅงใใพใ (Chat session started)")
        print("="*80)
        print("Commands:")
        print(" /quit or /exit - ็ตไบ (Exit)")
        print(" /clear - ไผ่ฉฑๅฑฅๆญดใใฏใชใข (Clear conversation history)")
        print(" /save - ไผ่ฉฑใไฟๅญ (Save conversation)")
        print(" /temp <value> - ๆธฉๅบฆใใฉใกใผใฟใ่จญๅฎ (Set temperature, e.g., /temp 0.8)")
        print(" /context on/off - ใณใณใใญในใไฝฟ็จใฎๅใๆฟใ (Toggle context usage)")
        print("-"*80)

        temperature = 0.1
        use_context = True

        while True:
            try:
                user_input = input("\n๐ค You: ").strip()

                # Ignore empty lines instead of generating from a blank prompt.
                if not user_input:
                    continue

                if user_input.lower() in ['/quit', '/exit', '/q']:
                    print("\n๐ ใใใใชใ๏ผ(Goodbye!)")
                    break

                elif user_input.lower() == '/clear':
                    self.conversation_history = []
                    print("โ ไผ่ฉฑๅฑฅๆญดใใฏใชใขใใพใใ (Conversation history cleared)")
                    continue

                elif user_input.lower() == '/save':
                    self.save_conversation()
                    continue

                elif user_input.lower().startswith('/temp'):
                    try:
                        # BUGFIX: previously the parsed value was immediately
                        # overwritten with 0.1, so /temp had no effect.
                        # Clamp to the range transformers accepts.
                        temperature = min(max(float(user_input.split()[1]), 0.0), 2.0)
                        print(f"โ Temperature set to {temperature}")
                    except (IndexError, ValueError):
                        print("โ Invalid temperature. Use: /temp 0.7")
                    continue

                elif user_input.lower().startswith('/context'):
                    try:
                        setting = user_input.split()[1].lower()
                        use_context = setting == 'on'
                        print(f"โ Context {'enabled' if use_context else 'disabled'}")
                    except IndexError:
                        print("โ Use: /context on or /context off")
                    continue

                elif user_input.startswith('/'):
                    print("โ Unknown command")
                    continue

                print("\n๐ค Counselor: ", end="", flush=True)
                response = self.generate_response(
                    user_input,
                    temperature=temperature,
                    use_context=use_context
                )
                print(response)

                self.conversation_history.append(f"Client: {user_input}")
                self.conversation_history.append(f"Counselor: {response}")

            except KeyboardInterrupt:
                print("\n\n๐ ใใใใชใ๏ผ(Goodbye!)")
                break
            except Exception as e:
                print(f"\nโ Error: {e}")
                continue

    def save_conversation(self):
        """Save the conversation history to a timestamped JSON file."""
        if not self.conversation_history:
            print("โ No conversation to save")
            return

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"conversation_{timestamp}.json"

        conversation_data = {
            "timestamp": timestamp,
            "model_path": self.model_path,
            "conversation": self.conversation_history
        }

        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(conversation_data, f, ensure_ascii=False, indent=2)

        # BUGFIX: previously printed a literal placeholder instead of the
        # actual output filename.
        print(f"โ Conversation saved to {filename}")

    def test_responses(self):
        """Smoke-test the model with predefined Japanese counseling inputs."""
        print("\n" + "="*80)
        print("๐งช Testing Model Responses")
        print("="*80)

        test_inputs = [
            "ใใใซใกใฏใๆ่ฟในใใฌในใๆใใฆใใพใใ",
            "ไปไบใใใพใใใใชใใฆๆฉใใงใใพใใ",
            "ไบบ้้ขไฟใงๅฐใฃใฆใใพใใใฉใใใใฐใใใงใใใใใ",
            "ๅฐๆฅใไธๅฎใง็ ใใพใใใ",
            "่ชๅใซ่ชไฟกใๆใฆใพใใใ",
            "ๅฎถๆใจใฎ้ขไฟใงๆฉใใงใใพใใ",
            "ๆฏๆฅใ่พใใงใใ",
            "่ชฐใซใ็ธ่ซใงใใพใใใ"
        ]

        print("\nTesting with different temperature settings:\n")

        # temp=0 exercises the greedy path; 0.1 exercises sampling.
        for temp in [0, 0.1]:
            print(f"\n๐ก๏ธ Temperature: {temp}")
            print("-"*60)

            for i, test_input in enumerate(test_inputs[:3], 1):
                print(f"\n{i}. Input: {test_input}")
                response = self.generate_response(test_input, temperature=temp, use_context=False)
                print(f" Response: {response[:200]}...")
                print()

        print("="*80)
|
|
|
|
def main():
    """Entry point: parse CLI arguments, then run the test suite or the chat REPL."""
    import argparse

    parser = argparse.ArgumentParser(description='Chat with fine-tuned counseling model')
    # NOTE(review): default path 'merged_counselor_mode_2b' differs from the
    # class default 'merged_counselor_model' — possibly a typo; confirm.
    parser.add_argument('--model_path', type=str, default='./merged_counselor_mode_2b',
                        help='Path to the fine-tuned model')
    parser.add_argument('--test_only', action='store_true',
                        help='Only run test responses without chat')
    args = parser.parse_args()

    # Bail out early with a listing of plausible model directories when the
    # requested path is missing.
    if not os.path.exists(args.model_path):
        print(f"โ Model not found at {args.model_path}")
        print("\nAvailable models:")
        candidates = [
            entry for entry in os.listdir('.')
            if 'model' in entry.lower() and os.path.isdir(entry)
        ]
        for entry in candidates:
            print(f" - {entry}")
        return

    try:
        interface = CounselorChatInterface(model_path=args.model_path)
        runner = interface.test_responses if args.test_only else interface.chat
        runner()
    except Exception as e:
        print(f"โ Error: {e}")
        import traceback
        traceback.print_exc()




if __name__ == "__main__":
    main()
|
|