# Techiiot's picture
# Upload folder using huggingface_hub
# 27c46c6 verified
"""
Interactive Chat Interface for Testing Fine-tuned Japanese Counseling Model
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import warnings
from datetime import datetime
import json
warnings.filterwarnings('ignore')
class CounselorChatInterface:
    """
    Interactive chat wrapper around a fine-tuned Japanese counseling model.

    Loads a local causal-LM checkpoint and its tokenizer, then offers:
      * chat()              - interactive REPL with slash commands
      * test_responses()    - scripted smoke test over canned inputs
      * save_conversation() - JSON export of the session history
    """

    def __init__(self, model_path: str = "./merged_counselor_model"):
        """
        Initialize the chat interface with the fine-tuned model.

        Args:
            model_path: Path to the fine-tuned model directory.
        """
        self.model_path = model_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("=" * 80)
        print("🎌 Japanese Counseling Model Chat Interface")
        print("=" * 80)
        # NOTE(review): emoji restored from mojibake; the pin glyph below is a
        # best guess for the original (fourth UTF-8 byte was lost) — confirm.
        print(f"📍 Device: {self.device}")
        if self.device.type == "cuda":
            print(f"   GPU: {torch.cuda.get_device_name(0)}")
            print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
        self.load_model()
        # Alternating "Client: ..." / "Counselor: ..." entries, oldest first.
        self.conversation_history = []

    def load_model(self):
        """Load the fine-tuned model and tokenizer, with a gpt2-tokenizer fallback."""
        print(f"\n🤖 Loading model from {self.model_path}...")
        try:
            # Load tokenizer from the local checkpoint only (no hub download).
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                local_files_only=True
            )
            # Decoder-only models often ship without a pad token; reuse EOS.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # fp16 on GPU (with accelerate's device_map), fp32 on CPU.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                device_map="auto" if self.device.type == "cuda" else None,
                local_files_only=True,
                trust_remote_code=True
            )
            self.model.eval()
            print("✅ Model loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            print("Trying alternative loading method...")
            # Fallback: base gpt2 tokenizer (may download from the hub) plus
            # the local model weights without device_map/trust_remote_code.
            try:
                self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
                if self.tokenizer.pad_token is None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_path,
                    torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                    local_files_only=True
                )
                self.model = self.model.to(self.device)
                self.model.eval()
                print("✅ Model loaded with fallback tokenizer!")
            except Exception as e2:
                print(f"❌ Failed to load model: {e2}")
                raise

    def generate_response(self, user_input: str,
                          temperature: float = 0.7,
                          max_length: int = 200,
                          use_context: bool = True) -> str:
        """
        Generate a counseling response.

        Args:
            user_input: User's message.
            temperature: Generation temperature, clamped to [0.1, 1.0].
                BUGFIX: the old default of 0 is invalid with do_sample=True —
                transformers requires a strictly positive temperature, so the
                default path always failed and returned the apology string.
            max_length: Maximum number of new tokens to generate.
            use_context: Whether to prepend recent conversation history.

        Returns:
            The generated response text, or a Japanese apology message on failure.
        """
        # Clamp into the documented range so sampling never sees temperature <= 0.
        temperature = max(0.1, min(1.0, temperature))

        # Build the Alpaca-style prompt, optionally with recent context.
        if use_context and len(self.conversation_history) > 0:
            context = "\n".join(self.conversation_history[-4:])  # last 2 exchanges
            prompt = f"""### Instruction:
あなたは思いやりのある心理カウンセラーです。
クライアントの感情を理解し、共感的で支援的な応答を提供してください。
### Context:
{context}
### Input:
{user_input}
### Response:
"""
        else:
            prompt = f"""### Instruction:
あなたは思いやりのある心理カウンセラーです。
クライアントの感情を理解し、共感的で支援的な応答を提供してください。
### Input:
{user_input}
### Response:
"""
        # Tokenize (truncate long context so prompt + generation fits the model).
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        )
        if self.device.type == "cuda":
            inputs = {k: v.cuda() for k, v in inputs.items()}

        try:
            with torch.no_grad():
                # torch.cuda.amp.autocast() is deprecated; torch.autocast
                # covers both the "cuda" and "cpu" device types.
                with torch.autocast(self.device.type):
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=max_length,
                        temperature=temperature,
                        do_sample=True,
                        top_p=0.9,
                        top_k=50,
                        repetition_penalty=1.1,
                        pad_token_id=self.tokenizer.pad_token_id,
                        eos_token_id=self.tokenizer.eos_token_id
                    )
            full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Keep only the text after the final "### Response:" marker; fall
            # back to stripping the prompt prefix if the marker was truncated.
            if "### Response:" in full_response:
                response = full_response.split("### Response:")[-1].strip()
            else:
                response = full_response[len(prompt):].strip()
            return response
        except Exception as e:
            print(f"Error generating response: {e}")
            # "Sorry, an error occurred while generating the response."
            return "申し訳ございません。応答の生成中にエラーが発生しました。"

    def chat(self):
        """Run the interactive chat loop until /quit, /exit, or Ctrl-C."""
        print("\n" + "=" * 80)
        print("💬 チャットを開始します (Chat session started)")
        print("=" * 80)
        print("Commands:")
        print("  /quit or /exit - 終了 (Exit)")
        print("  /clear - 会話履歴をクリア (Clear conversation history)")
        print("  /save - 会話を保存 (Save conversation)")
        print("  /temp <value> - 温度パラメータを設定 (Set temperature, e.g., /temp 0.8)")
        print("  /context on/off - コンテキスト使用の切り替え (Toggle context usage)")
        print("-" * 80)
        temperature = 0.1
        use_context = True
        while True:
            try:
                user_input = input("\n👤 You: ").strip()
                # Slash commands are handled before any generation.
                if user_input.lower() in ['/quit', '/exit', '/q']:
                    print("\n👋 さようなら！(Goodbye!)")
                    break
                elif user_input.lower() == '/clear':
                    self.conversation_history = []
                    print("✅ 会話履歴をクリアしました (Conversation history cleared)")
                    continue
                elif user_input.lower() == '/save':
                    self.save_conversation()
                    continue
                elif user_input.lower().startswith('/temp'):
                    try:
                        # BUGFIX: previously the parsed value was discarded and
                        # temperature was hard-reset to 0.1; clamp to [0.1, 1.0].
                        temperature = max(0.1, min(1.0, float(user_input.split()[1])))
                        print(f"✅ Temperature set to {temperature}")
                    except (IndexError, ValueError):
                        print("❌ Invalid temperature. Use: /temp 0.7")
                    continue
                elif user_input.lower().startswith('/context'):
                    try:
                        setting = user_input.split()[1].lower()
                        use_context = setting == 'on'
                        print(f"✅ Context {'enabled' if use_context else 'disabled'}")
                    except IndexError:
                        print("❌ Use: /context on or /context off")
                    continue
                elif user_input.startswith('/'):
                    print("❌ Unknown command")
                    continue

                print("\n🤖 Counselor: ", end="", flush=True)
                response = self.generate_response(
                    user_input,
                    temperature=temperature,
                    use_context=use_context
                )
                print(response)
                self.conversation_history.append(f"Client: {user_input}")
                self.conversation_history.append(f"Counselor: {response}")
            except KeyboardInterrupt:
                print("\n\n👋 さようなら！(Goodbye!)")
                break
            except Exception as e:
                # Keep the REPL alive on any generation/IO error.
                print(f"\n❌ Error: {e}")
                continue

    def save_conversation(self):
        """Save the conversation history to a timestamped JSON file in the CWD."""
        if not self.conversation_history:
            print("❌ No conversation to save")
            return
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"conversation_{timestamp}.json"
        conversation_data = {
            "timestamp": timestamp,
            "model_path": self.model_path,
            "conversation": self.conversation_history
        }
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(conversation_data, f, ensure_ascii=False, indent=2)
        # BUGFIX: report the actual filename (message previously said "(unknown)").
        print(f"✅ Conversation saved to {filename}")

    def test_responses(self):
        """Smoke-test the model against predefined Japanese counseling inputs."""
        print("\n" + "=" * 80)
        print("🧪 Testing Model Responses")
        print("=" * 80)
        test_inputs = [
            "こんにちは。最近ストレスを感じています。",
            "仕事がうまくいかなくて悩んでいます。",
            "人間関係で困っています。どうすればいいでしょうか。",
            "将来が不安で眠れません。",
            "自分に自信が持てません。",
            "家族との関係で悩んでいます。",
            "毎日が辛いです。",
            "誰にも相談できません。"
        ]
        print("\nTesting with different temperature settings:\n")
        # BUGFIX: the old temperatures [0, 0.1] included an invalid value
        # (0 raises with do_sample=True) and were indistinguishable after
        # clamping; use two genuinely different valid settings.
        for temp in [0.3, 0.7]:
            print(f"\n🌡️ Temperature: {temp}")
            print("-" * 60)
            for i, test_input in enumerate(test_inputs[:3], 1):
                print(f"\n{i}. Input: {test_input}")
                response = self.generate_response(test_input, temperature=temp, use_context=False)
                print(f"   Response: {response[:200]}...")
                print()
        print("=" * 80)
def main():
    """CLI entry point: parse arguments, validate the model path, then chat or test."""
    import argparse
    parser = argparse.ArgumentParser(description='Chat with fine-tuned counseling model')
    # NOTE(review): this default differs from CounselorChatInterface's
    # "./merged_counselor_model" — confirm "_mode_2b" is intentional.
    parser.add_argument('--model_path', type=str, default='./merged_counselor_mode_2b',
                        help='Path to the fine-tuned model')
    parser.add_argument('--test_only', action='store_true',
                        help='Only run test responses without chat')
    args = parser.parse_args()

    # Fail fast with a helpful listing if the checkpoint directory is missing.
    if not os.path.exists(args.model_path):
        print(f"❌ Model not found at {args.model_path}")
        print("\nAvailable models:")
        # Best-effort: show sibling directories that look like checkpoints.
        for item in os.listdir('.'):
            if 'model' in item.lower() and os.path.isdir(item):
                print(f"  - {item}")
        return

    try:
        chat = CounselorChatInterface(model_path=args.model_path)
        if args.test_only:
            chat.test_responses()
        else:
            chat.chat()
    except Exception as e:
        # Top-level boundary: report and dump the traceback rather than crash silently.
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()