File size: 6,887 Bytes
69f4cd6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
#!/usr/bin/env python3
"""
AgGPT-21 Interactive Chat Interface
A conversational interface for the trained AgGPT-21 model.
"""
import os
import sys
import torch
from AgGPT21 import WordRNN, generate_text, MODEL_FILE, DEVICE
def load_model():
    """Load the trained AgGPT-21 checkpoint and return (model, stoi, itos).

    Returns:
        model: a WordRNN in eval mode with the checkpoint weights loaded.
        stoi:  word -> index vocabulary mapping from the checkpoint.
        itos:  index -> word mapping from the checkpoint.

    Exits the process with status 1 if the checkpoint file is missing or
    fails to deserialize, after printing a user-facing message.
    """
    if not os.path.exists(MODEL_FILE):
        print(f"β Model file '{MODEL_FILE}' not found!")
        print("Please train the model first by running: python AgGPT21.py")
        sys.exit(1)
    try:
        print("π Loading AgGPT-21 model...")
        # map_location ensures weights saved on another device load onto DEVICE.
        ckpt = torch.load(MODEL_FILE, map_location=DEVICE)
        stoi = ckpt["stoi"]  # word -> index
        itos = ckpt["itos"]  # index -> word
        model = WordRNN(len(stoi))
        model.load_state_dict(ckpt["model_state"])
        model.eval()  # inference mode (disables dropout, etc.)
        # Count trainable parameters for the status report.
        param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
        # FIX: this message was a line-broken f-string in the original
        # source (a syntax error); rejoined onto a single line.
        print(f"β Model loaded successfully!")
        print(f" β’ Parameters: {param_count:,}")
        print(f" β’ Vocabulary size: {len(stoi):,}")
        print(f" β’ Device: {DEVICE}")
        print()
        return model, stoi, itos
    except Exception as e:
        # Catch-all is deliberate here: any load failure is fatal for a CLI,
        # so report it and exit rather than crash with a traceback.
        print(f"β Error loading model: {e}")
        sys.exit(1)
def print_banner():
    """Write the AgGPT-21 welcome banner to stdout."""
    # The literal is passed straight to print; no intermediate variable.
    print("""
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β π€ AgGPT-21 π€ β
β Interactive Chat Interface β
β β
β β’ Type your message and press Enter to chat β
β β’ Use 'quit', 'exit', or 'bye' to end the conversation β
β β’ Use 'help' for more options β
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
""")
def print_help():
    """Write the chat command reference to stdout."""
    # The literal is passed straight to print; no intermediate variable.
    print("""
π§ AgGPT-21 Chat Commands:
β’ Just type your message to chat with the AI
β’ 'quit', 'exit', 'bye' - End the conversation
β’ 'help' - Show this help message
β’ 'clear' - Clear the screen
β’ 'model' - Show model information
β’ 'temp X' - Set temperature (e.g., 'temp 0.8')
β’ 'length X' - Set response length (e.g., 'length 150')
ποΈ Current Settings:
β’ Temperature: Controls creativity (0.1-2.0, default: 0.9)
β’ Length: Number of words to generate (50-500, default: 200)
""")
def _handle_command(user_input, model, stoi, settings):
    """Dispatch one line of user input as a command.

    Args:
        user_input: the stripped line the user typed.
        model:      the loaded model (used by the 'model' info command).
        stoi:       vocabulary mapping (used for the vocab-size report).
        settings:   mutable dict with 'temperature' and 'length' keys,
                    updated in place by the 'temp'/'length' commands.

    Returns:
        'quit'    - the user asked to end the session.
        'handled' - the line was a command; nothing more to do.
        'chat'    - not a command; caller should generate a reply.
    """
    user_lower = user_input.lower()
    if user_lower in ('quit', 'exit', 'bye'):
        print("\nπ Goodbye! Thanks for chatting with AgGPT-21!")
        return "quit"
    if user_lower == 'help':
        print_help()
        return "handled"
    if user_lower == 'clear':
        # Platform-appropriate clear-screen, then re-show the banner.
        os.system('clear' if os.name == 'posix' else 'cls')
        print_banner()
        return "handled"
    if user_lower == 'model':
        param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"\nπ€ Model Information:")
        print(f" β’ Parameters: {param_count:,}")
        print(f" β’ Vocabulary: {len(stoi):,} words")
        print(f" β’ Device: {DEVICE}")
        print(f" β’ Temperature: {settings['temperature']}")
        print(f" β’ Max length: {settings['length']}")
        return "handled"
    if user_lower.startswith('temp '):
        try:
            new_temp = float(user_lower.split()[1])
            if 0.1 <= new_temp <= 2.0:
                settings['temperature'] = new_temp
                print(f"π‘οΈ Temperature set to {settings['temperature']}")
            else:
                print("β Temperature must be between 0.1 and 2.0")
        except (IndexError, ValueError):
            print("β Invalid temperature. Use: temp 0.8")
        return "handled"
    if user_lower.startswith('length '):
        try:
            new_length = int(user_lower.split()[1])
            if 50 <= new_length <= 500:
                settings['length'] = new_length
                print(f"π Response length set to {settings['length']} words")
            else:
                print("β Length must be between 50 and 500")
        except (IndexError, ValueError):
            print("β Invalid length. Use: length 150")
        return "handled"
    return "chat"


def _chat_turn(model, stoi, itos, user_input, settings):
    """Generate and print one model reply to *user_input*."""
    print(f"\nπ€ AgGPT-21 (thinking...)", end="", flush=True)
    try:
        response = generate_text(
            model=model,
            stoi=stoi,
            itos=itos,
            prompt=user_input,
            length=settings['length'],
            temperature=settings['temperature'],
            top_k=settings['top_k'],
            top_p=settings['top_p'],
            device=DEVICE,
        )
        # Strip the echoed prompt from the front of the generated text.
        # NOTE(review): this assumes generate_text returns prompt + continuation;
        # if the output does not begin with the prompt, the first
        # len(prompt_words) words are dropped regardless - confirm upstream.
        response_words = response.split()
        prompt_words = user_input.lower().split()
        if len(response_words) > len(prompt_words):
            ai_response = " ".join(response_words[len(prompt_words):])
        else:
            ai_response = response
        # \r overwrites the "(thinking...)" placeholder on the same line.
        print(f"\rπ€ AgGPT-21: {ai_response}")
    except Exception as e:
        print(f"\rβ Error generating response: {e}")


def main():
    """Run the interactive chat REPL.

    Loads the model once, then loops forever: read a line, dispatch
    commands (help / clear / model / temp / length / quit), otherwise
    generate and print a model reply. Ctrl-C exits cleanly.
    """
    print_banner()
    model, stoi, itos = load_model()
    # Generation settings; 'temperature' and 'length' are user-adjustable.
    settings = {
        'temperature': 0.9,
        'length': 200,
        'top_k': 50,
        'top_p': 0.9,
    }
    print("π¬ Chat started! Type your message below:")
    print("=" * 70)
    while True:
        try:
            user_input = input("\nπ€ You: ").strip()
            if not user_input:
                continue
            action = _handle_command(user_input, model, stoi, settings)
            if action == "quit":
                break
            if action == "handled":
                continue
            _chat_turn(model, stoi, itos, user_input, settings)
        except KeyboardInterrupt:
            print("\n\nπ Chat interrupted. Goodbye!")
            break
        except Exception as e:
            # Top-level boundary: report and keep the REPL alive.
            print(f"\nβ Unexpected error: {e}")
# Entry point: run the chat interface only when executed as a script,
# not when imported as a module.
if __name__ == "__main__":
    main()
|