|
|
|
|
|
""" |
|
|
AgGPT-21 Interactive Chat Interface |
|
|
A conversational interface for the trained AgGPT-21 model. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import torch |
|
|
from AgGPT21 import WordRNN, generate_text, MODEL_FILE, DEVICE |
|
|
|
|
|
def load_model():
    """Load the trained AgGPT-21 checkpoint from MODEL_FILE.

    Returns:
        tuple: ``(model, stoi, itos)`` — the restored ``WordRNN`` in eval
        mode, the word→id vocabulary dict, and the id→word dict.

    Exits the process with status 1 if the checkpoint file is missing or
    cannot be loaded (this is a CLI entry helper, so exiting is the
    intended failure mode rather than raising).
    """
    if not os.path.exists(MODEL_FILE):
        print(f"β Model file '{MODEL_FILE}' not found!")
        print("Please train the model first by running: python AgGPT21.py")
        sys.exit(1)

    try:
        print("π Loading AgGPT-21 model...")
        ckpt = torch.load(MODEL_FILE, map_location=DEVICE)
        stoi = ckpt["stoi"]
        itos = ckpt["itos"]
        model = WordRNN(len(stoi))
        model.load_state_dict(ckpt["model_state"])
        model.eval()  # inference mode: disables dropout etc.

        param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
        # BUG FIX: this message was split across two physical lines inside a
        # single-quoted f-string, which is a SyntaxError. Joined into one line,
        # matching the style of the sibling status messages.
        print(f"β Model loaded successfully!")
        print(f"   β’ Parameters: {param_count:,}")
        print(f"   β’ Vocabulary size: {len(stoi):,}")
        print(f"   β’ Device: {DEVICE}")
        print()

        return model, stoi, itos
    except Exception as e:
        # Broad catch is deliberate: any load failure (corrupt file, schema
        # mismatch) should produce a friendly message, not a traceback.
        print(f"β Error loading model: {e}")
        sys.exit(1)
|
|
|
|
|
def print_banner():
    """Write the AgGPT-21 welcome banner to stdout."""
    # Printed as a single literal; no intermediate variable needed.
    print("""
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β                         π€ AgGPT-21 π€                               β
β                  Interactive Chat Interface                           β
β                                                                       β
β  β’ Type your message and press Enter to chat                          β
β  β’ Use 'quit', 'exit', or 'bye' to end the conversation               β
β  β’ Use 'help' for more options                                        β
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
""")
|
|
|
|
|
def print_help():
    """Write the chat command reference to stdout."""
    # Printed as a single literal; no intermediate variable needed.
    print("""
π§ AgGPT-21 Chat Commands:
  β’ Just type your message to chat with the AI
  β’ 'quit', 'exit', 'bye' - End the conversation
  β’ 'help' - Show this help message
  β’ 'clear' - Clear the screen
  β’ 'model' - Show model information
  β’ 'temp X' - Set temperature (e.g., 'temp 0.8')
  β’ 'length X' - Set response length (e.g., 'length 150')

ποΈ Current Settings:
  β’ Temperature: Controls creativity (0.1-2.0, default: 0.9)
  β’ Length: Number of words to generate (50-500, default: 200)
""")
|
|
|
|
|
def _strip_prompt_echo(response, prompt):
    """Remove the leading word-wise echo of *prompt* from *response*.

    generate_text() returns text that starts with the prompt itself, so the
    first len(prompt-words) words are dropped to leave only the model's
    continuation. If the response is not longer than the prompt (nothing was
    generated beyond the echo), the full response is returned unchanged.
    """
    response_words = response.split()
    prompt_word_count = len(prompt.lower().split())
    if len(response_words) > prompt_word_count:
        return " ".join(response_words[prompt_word_count:])
    return response


def main():
    """Run the interactive chat REPL against the trained model.

    Handles meta commands (help/clear/model/temp/length/quit) inline and
    forwards everything else to generate_text(). Loops until the user quits
    or presses Ctrl-C.
    """
    print_banner()

    model, stoi, itos = load_model()

    # Generation settings; 'temp' and 'length' commands mutate these at runtime.
    temperature = 0.9
    length = 200
    top_k = 50
    top_p = 0.9

    print("π¬ Chat started! Type your message below:")
    print("=" * 70)

    while True:
        try:
            user_input = input("\nπ€ You: ").strip()

            if not user_input:
                continue

            user_lower = user_input.lower()

            if user_lower in ['quit', 'exit', 'bye']:
                print("\nπ Goodbye! Thanks for chatting with AgGPT-21!")
                break

            elif user_lower == 'help':
                print_help()
                continue

            elif user_lower == 'clear':
                # 'clear' on POSIX shells, 'cls' on Windows.
                os.system('clear' if os.name == 'posix' else 'cls')
                print_banner()
                continue

            elif user_lower == 'model':
                param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
                print(f"\nπ€ Model Information:")
                print(f"   β’ Parameters: {param_count:,}")
                print(f"   β’ Vocabulary: {len(stoi):,} words")
                print(f"   β’ Device: {DEVICE}")
                print(f"   β’ Temperature: {temperature}")
                print(f"   β’ Max length: {length}")
                continue

            elif user_lower.startswith('temp '):
                try:
                    new_temp = float(user_lower.split()[1])
                    if 0.1 <= new_temp <= 2.0:
                        temperature = new_temp
                        print(f"π‘οΈ Temperature set to {temperature}")
                    else:
                        print("β Temperature must be between 0.1 and 2.0")
                except (IndexError, ValueError):
                    print("β Invalid temperature. Use: temp 0.8")
                continue

            elif user_lower.startswith('length '):
                try:
                    new_length = int(user_lower.split()[1])
                    if 50 <= new_length <= 500:
                        length = new_length
                        print(f"π Response length set to {length} words")
                    else:
                        print("β Length must be between 50 and 500")
                except (IndexError, ValueError):
                    print("β Invalid length. Use: length 150")
                continue

            print(f"\nπ€ AgGPT-21 (thinking...)", end="", flush=True)

            try:
                response = generate_text(
                    model=model,
                    stoi=stoi,
                    itos=itos,
                    prompt=user_input,
                    length=length,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    device=DEVICE
                )

                ai_response = _strip_prompt_echo(response, user_input)

                # BUG FIX: '\r' alone only returns the cursor, so a reply
                # shorter than the "(thinking...)" text left residue on the
                # line. Blank the line out first, then print the reply.
                print("\r" + " " * 30 + "\r", end="")
                print(f"π€ AgGPT-21: {ai_response}")

            except Exception as e:
                print("\r" + " " * 30 + "\r", end="")
                print(f"β Error generating response: {e}")

        except KeyboardInterrupt:
            print("\n\nπ Chat interrupted. Goodbye!")
            break
        except Exception as e:
            # Top-level guard: keep the REPL alive on unexpected errors.
            print(f"\nβ Unexpected error: {e}")
|
|
|
|
|
# Script entry point: start the interactive chat only when run directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
|
|
|