File size: 6,887 Bytes
69f4cd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
#!/usr/bin/env python3
"""
AgGPT-21 Interactive Chat Interface
A conversational interface for the trained AgGPT-21 model.
"""

import os
import sys
import torch
from AgGPT21 import WordRNN, generate_text, MODEL_FILE, DEVICE

def load_model():
    """Load the trained AgGPT-21 checkpoint from ``MODEL_FILE``.

    The checkpoint is read with ``torch.load`` and indexed with the keys
    ``"stoi"``, ``"itos"`` and ``"model_state"``. A ``WordRNN`` is built with
    ``len(stoi)`` as its vocabulary size, loaded with ``model_state``, moved
    to ``DEVICE`` and switched to eval mode.

    Returns:
        tuple: ``(model, stoi, itos)``.

    Exits the process with status 1 (after printing a message) if the
    checkpoint file does not exist, lacks one of the required keys, or
    fails to load for any other reason.
    """
    if not os.path.exists(MODEL_FILE):
        print(f"❌ Model file '{MODEL_FILE}' not found!")
        print("Please train the model first by running: python AgGPT21.py")
        sys.exit(1)

    print("🔄 Loading AgGPT-21 model...")
    try:
        ckpt = torch.load(MODEL_FILE, map_location=DEVICE)
        stoi = ckpt["stoi"]
        itos = ckpt["itos"]
        model = WordRNN(len(stoi))
        model.load_state_dict(ckpt["model_state"])
        # map_location only places the *loaded* tensors; load_state_dict copies
        # them into the model's existing parameters, so move the model itself
        # to DEVICE (a no-op if WordRNN already lives there).
        model.to(DEVICE)
    except KeyError as e:
        # A plain "Error loading model: 'stoi'" is cryptic; name the problem.
        print(f"❌ Error loading model: checkpoint is missing key {e}")
        sys.exit(1)
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        sys.exit(1)

    # Inference only: disables dropout / training-mode behaviour.
    model.eval()

    # Only trainable parameters (requires_grad=True) are counted.
    param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("✅ Model loaded successfully!")
    print(f"   • Parameters: {param_count:,}")
    print(f"   • Vocabulary size: {len(stoi):,}")
    print(f"   • Device: {DEVICE}")
    print()

    return model, stoi, itos

def print_banner():
    """Print the AgGPT-21 welcome banner inside a box-drawing frame.

    The frame has an interior width of 70 columns, matching the ``"=" * 70``
    separator printed when the chat starts. Rows are padded by terminal
    *display* width (East Asian Wide/Fullwidth characters, which include the
    robot emoji, count as two columns) so the right border stays aligned.
    """
    import unicodedata

    inner_width = 70

    def display_width(text):
        # Wide ("W") and Fullwidth ("F") characters occupy two columns.
        return sum(
            2 if unicodedata.east_asian_width(ch) in ("W", "F") else 1
            for ch in text
        )

    def framed(text, center=False):
        pad = inner_width - display_width(text)
        left = pad // 2 if center else 0
        return "║" + " " * left + text + " " * (pad - left) + "║"

    lines = [
        "╔" + "═" * inner_width + "╗",
        framed("🤖 AgGPT-21 🤖", center=True),
        framed("Interactive Chat Interface", center=True),
        framed(""),
        framed("  • Type your message and press Enter to chat"),
        framed("  • Use 'quit', 'exit', or 'bye' to end the conversation"),
        framed("  • Use 'help' for more options"),
        "╚" + "═" * inner_width + "╝",
    ]
    # Blank line before and after, as the original triple-quoted banner had.
    print("\n" + "\n".join(lines) + "\n")

def print_help():
    """Print the chat commands and the adjustable generation settings.

    The ranges and defaults listed here mirror the validation and initial
    values used by ``main()``; keep them in sync if either changes. The
    live values are shown by the 'model' command, not here.
    """
    help_text = """
🔧 AgGPT-21 Chat Commands:
• Just type your message to chat with the AI
• 'quit', 'exit', 'bye' - End the conversation
• 'help' - Show this help message
• 'clear' - Clear the screen
• 'model' - Show model information and the current settings
• 'temp X' - Set temperature (e.g., 'temp 0.8')
• 'length X' - Set response length (e.g., 'length 150')

🎛️ Adjustable Settings:
• Temperature: Controls creativity (0.1-2.0, default: 0.9)
• Length: Number of words to generate (50-500, default: 200)
    """
    print(help_text)

def _strip_prompt_echo(response, prompt):
    """Return the newly generated part of *response*, dropping the echoed *prompt*.

    ``main`` passes the user's text to ``generate_text`` as the prompt and
    treats the output as beginning with it. Prompt words are compared
    case-insensitively (against ``prompt.lower()``).

    Args:
        response: Text returned by ``generate_text``.
        prompt: The user's input that was used as the prompt.

    Returns:
        The words after the first ``len(prompt.split())`` words, re-joined
        with single spaces, when the response starts with the prompt
        (possibly ``""`` if nothing new was generated) or is longer than it;
        otherwise *response* unchanged.
    """
    response_words = response.split()
    prompt_words = prompt.lower().split()
    n_prompt = len(prompt_words)
    # Exact (case-insensitive) echo: strip it even when nothing follows, so
    # the user's own message is never printed back as the model's reply.
    echoes_prompt = [w.lower() for w in response_words[:n_prompt]] == prompt_words
    # Longer-than-prompt fallback keeps the positional heuristic for outputs
    # whose words don't line up exactly with the raw input.
    if echoes_prompt or len(response_words) > n_prompt:
        return " ".join(response_words[n_prompt:])
    return response


def main():
    """Run the interactive AgGPT-21 chat loop.

    Prints the banner, loads the model, then repeatedly reads a line from
    stdin and either handles it as a command (quit/exit/bye, help, clear,
    model, ``temp X``, ``length X``) or sends it to ``generate_text`` as a
    prompt. Exits on 'quit'/'exit'/'bye', Ctrl-C, or end of input
    (Ctrl-D / closed stdin).
    """
    print_banner()

    # Exits the process itself if the checkpoint is missing or broken.
    model, stoi, itos = load_model()

    # Generation settings: temperature and length are adjustable at runtime
    # via 'temp X' / 'length X'; top_k and top_p are fixed.
    temperature = 0.9
    length = 200
    top_k = 50
    top_p = 0.9

    status = "🤖 AgGPT-21 (thinking...)"
    # "\r" only moves the cursor, so blank the status line first (+2 covers
    # the double-width emoji); otherwise a short reply leaves "(thinking...)"
    # residue at the end of the line.
    clear_status = "\r" + " " * (len(status) + 2) + "\r"

    print("💬 Chat started! Type your message below:")
    print("=" * 70)

    while True:
        try:
            user_input = input("\n👤 You: ").strip()
            if not user_input:
                continue

            user_lower = user_input.lower()
            words = user_lower.split()

            if user_lower in ('quit', 'exit', 'bye'):
                print("\n👋 Goodbye! Thanks for chatting with AgGPT-21!")
                break

            if user_lower == 'help':
                print_help()
                continue

            if user_lower == 'clear':
                os.system('clear' if os.name == 'posix' else 'cls')
                print_banner()
                continue

            if user_lower == 'model':
                param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
                print("\n🤖 Model Information:")
                print(f"   • Parameters: {param_count:,}")
                print(f"   • Vocabulary: {len(stoi):,} words")
                print(f"   • Device: {DEVICE}")
                print(f"   • Temperature: {temperature}")
                print(f"   • Max length: {length}")
                continue

            # 'temp' / 'length' are commands even without an argument, so a
            # bare "temp" reports usage instead of being sent to the model.
            if words[0] == 'temp':
                try:
                    new_temp = float(words[1])
                except (IndexError, ValueError):
                    print("❌ Invalid temperature. Use: temp 0.8")
                    continue
                if 0.1 <= new_temp <= 2.0:
                    temperature = new_temp
                    print(f"🌡️ Temperature set to {temperature}")
                else:
                    print("❌ Temperature must be between 0.1 and 2.0")
                continue

            if words[0] == 'length':
                try:
                    new_length = int(words[1])
                except (IndexError, ValueError):
                    print("❌ Invalid length. Use: length 150")
                    continue
                if 50 <= new_length <= 500:
                    length = new_length
                    print(f"📏 Response length set to {length} words")
                else:
                    print("❌ Length must be between 50 and 500")
                continue

            # Anything else is a chat message: generate a reply.
            print("\n" + status, end="", flush=True)
            try:
                response = generate_text(
                    model=model,
                    stoi=stoi,
                    itos=itos,
                    prompt=user_input,
                    length=length,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    device=DEVICE,
                )
                ai_response = _strip_prompt_echo(response, user_input)
            except Exception as e:
                print(f"{clear_status}❌ Error generating response: {e}")
                continue

            print(f"{clear_status}🤖 AgGPT-21: {ai_response or '(no response generated)'}")

        except KeyboardInterrupt:
            print("\n\n👋 Chat interrupted. Goodbye!")
            break
        except EOFError:
            # input() raises EOFError on Ctrl-D / closed stdin. Without this
            # handler it fell into the generic branch below and the loop
            # spun forever printing errors.
            print("\n\n👋 End of input. Goodbye!")
            break
        except Exception as e:
            print(f"\n❌ Unexpected error: {e}")

# Start the interactive chat only when run as a script, not when imported.
if __name__ == "__main__":
    main()