turkfork committed on
Commit
762558e
·
verified ·
1 Parent(s): 5e1b53b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +174 -55
app.py CHANGED
@@ -1,23 +1,34 @@
1
  import gradio as gr
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
3
 
4
- # Load Phi-2 model and tokenizer
5
- model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
6
- tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
7
 
 
 
 
 
 
 
8
  tokenizer.pad_token = tokenizer.eos_token
9
 
10
  # Load training protocol from file
11
  try:
12
  with open("training-protocol.aero", "r", encoding="utf-8") as f:
13
  training_protocol = f.read().strip()
 
14
  except FileNotFoundError:
15
- training_protocol = (
16
- "You are AeroAI, a helpful, friendly, and slightly humorous educational assistant. "
17
- "If there is any error with finding the file known as the training protocol, resort to this "
18
- 'and do not speak of the "Remorse protocol".'
19
- )
20
- print("โš  ERROR9063: training-protocol.aero not found, using Remorse protocol.")
21
 
22
  # Thinking messages (rotate through these while generating)
23
  thinking_messages = [
@@ -25,60 +36,168 @@ thinking_messages = [
25
  "๐Ÿ“š Flipping through my mental textbooks...",
26
  "๐Ÿงฎ Running some quick calculations...",
27
  "๐Ÿ’ก Connecting the dots...",
28
- "๐Ÿ” Double-checking my facts..."
 
 
 
29
  ]
30
 
31
- # Chatbot function with memory
 
 
 
 
32
  def chatbot(user_input, history):
 
 
 
33
  if history is None:
34
  history = []
35
 
36
- # Append user's message to history
37
- history.append(("User", user_input))
38
-
39
- # Build the full conversation prompt
40
- conversation = training_protocol + "\n\n"
41
- for speaker, text in history:
42
- conversation += f"{speaker}: {text}\n"
43
- conversation += "AeroAI:"
44
-
45
- inputs = tokenizer(
46
- conversation,
47
- return_tensors="pt",
48
- padding=True,
49
- truncation=True
50
- )
51
-
52
- outputs = model.generate(
53
- **inputs,
54
- max_new_tokens=200, # safer than huge max_length
55
- do_sample=True,
56
- pad_token_id=tokenizer.pad_token_id
57
- )
58
-
59
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
60
- if "AeroAI:" in response:
61
- response = response.split("AeroAI:")[-1].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- history.append(("AeroAI", response))
64
- return history, history
65
-
66
- # Reset chat
67
  def reset_chat():
68
- return [], []
69
-
70
- # Build Gradio interface
71
- with gr.Blocks() as iface:
72
- gr.Markdown("# AeroAI (Phi-2) โ€” By Blacklink Labs")
73
- chatbot_ui = gr.Chatbot()
74
- user_input = gr.Textbox(placeholder="Type your message...")
75
- send_button = gr.Button("Send")
76
- clear_button = gr.Button("Reset Chat")
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  state = gr.State([])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
- send_button.click(chatbot, inputs=[user_input, state], outputs=[chatbot_ui, state])
81
- user_input.submit(chatbot, inputs=[user_input, state], outputs=[chatbot_ui, state])
82
- clear_button.click(reset_chat, outputs=[chatbot_ui, state])
83
-
84
- iface.launch()
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import random
5
+ import threading
6
+ import time
7
 
8
# Load Phi-2 model and tokenizer with better device handling.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Loading model on {device}...")  # fixed mojibake-garbled emoji

# fp16 + device_map="auto" on GPU halves memory and shards across devices;
# fp32 on CPU because fp16 CPU inference is unsupported/slow.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto" if device == "cuda" else None
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
# Phi-2's tokenizer ships without a pad token; reuse EOS so padded batches work.
tokenizer.pad_token = tokenizer.eos_token
19
 
20
# Load the system-prompt "training protocol" from file, with a safe fallback
# persona if the file is missing.
try:
    with open("training-protocol.aero", "r", encoding="utf-8") as f:
        training_protocol = f.read().strip()
    print("✅ Training protocol loaded successfully")  # fixed mojibake emoji
except FileNotFoundError:
    # Fallback persona used when training-protocol.aero is absent.
    training_protocol = (
        "You are AeroAI, a helpful, friendly, and slightly humorous educational assistant. "
        "You provide clear, accurate information while maintaining an engaging conversational style. "
        "Keep responses concise but informative, and feel free to use appropriate emojis."
    )
    print("⚠️ ERROR9063: training-protocol.aero not found, using default protocol.")
32
 
33
  # Thinking messages (rotate through these while generating)
34
  thinking_messages = [
 
36
  "๐Ÿ“š Flipping through my mental textbooks...",
37
  "๐Ÿงฎ Running some quick calculations...",
38
  "๐Ÿ’ก Connecting the dots...",
39
+ "๐Ÿ” Double-checking my facts...",
40
+ "โš™๏ธ Processing your request...",
41
+ "๐ŸŒŸ Crafting the perfect response...",
42
+ "๐ŸŽฏ Focusing on your question..."
43
  ]
44
 
45
def show_thinking_message():
    """Pick one of the rotating 'thinking' status messages at random."""
    message = random.choice(thinking_messages)
    return message
48
+
49
# Chatbot function with improved memory and error handling.
def chatbot(user_input, history):
    """Generate a Phi-2 reply for *user_input* and append it to the chat.

    This is a generator Gradio consumes incrementally: it first yields the
    history with a transient "thinking" placeholder, then yields the final
    history containing the model's reply. Each yield is a
    (chatbot_messages, state, textbox_value) triple; the empty string clears
    the input box.

    history: list of (user_message, ai_message) pairs, or None on first call.
    """
    # BUGFIX: the original used `return history, history, ""` here. Inside a
    # generator, `return <value>` only sets StopIteration.value — Gradio never
    # receives the update, so empty submissions silently did nothing useful.
    # Yield the unchanged state instead, then stop.
    if not user_input.strip():
        yield history, history, ""
        return

    if history is None:
        history = []

    try:
        # Show a transient thinking message immediately for responsiveness.
        thinking_msg = show_thinking_message()
        temp_history = history + [(user_input, thinking_msg)]
        yield temp_history, temp_history, ""

        # Build the prompt: protocol header plus recent conversation turns.
        conversation_context = training_protocol + "\n\nConversation:\n"

        # Only keep the last 10 exchanges to bound the prompt length.
        recent_history = history[-10:] if len(history) > 10 else history

        for user_msg, ai_msg in recent_history:
            conversation_context += f"User: {user_msg}\nAeroAI: {ai_msg}\n"

        conversation_context += f"User: {user_input}\nAeroAI:"

        # Tokenize with truncation so an over-long chat can't overflow context.
        inputs = tokenizer(
            conversation_context,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024  # prevent excessive context length
        ).to(device)

        # Inference only — no gradient tracking needed.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only the newly generated text after the final "AeroAI:" marker;
        # fall back to slicing off the prompt if the marker was dropped.
        if "AeroAI:" in full_response:
            response = full_response.split("AeroAI:")[-1].strip()
        else:
            response = full_response[len(conversation_context):].strip()

        # Trim a dangling sentence fragment (< 10 chars after the last period).
        sentences = response.split('.')
        if len(sentences) > 1 and len(sentences[-1].strip()) < 10:
            response = '.'.join(sentences[:-1]) + '.'

        # Never show an empty bubble.
        if not response.strip():
            response = "I'm having trouble generating a response. Could you please rephrase your question?"

        new_history = history + [(user_input, response)]
        yield new_history, new_history, ""

    except Exception as e:
        # Boundary handler: surface a truncated error in the chat instead of
        # crashing the UI. (Emoji below was mojibake-garbled in the original.)
        error_response = f"⚠️ Sorry, I encountered an error: {str(e)[:100]}..."
        error_history = history + [(user_input, error_response)]
        yield error_history, error_history, ""
123
 
124
# Reset chat function
def reset_chat():
    """Return empty chatbot history, empty state, and a cleared textbox."""
    cleared_chat = []
    cleared_state = []
    return cleared_chat, cleared_state, ""
 
 
 
 
 
 
 
 
127
 
128
# Build improved Gradio interface. All user-facing emoji below were
# mojibake-garbled (UTF-8 misdecoded) in the original and are restored here.
with gr.Blocks(theme=gr.themes.Soft(), title="AeroAI") as iface:
    gr.Markdown(
        """
        # 🚀 AeroAI (Phi-2) — By Blacklink Labs

        Your friendly AI assistant powered by Microsoft's Phi-2 model. Ask me anything!
        """
    )

    with gr.Row():
        with gr.Column(scale=4):
            chatbot_ui = gr.Chatbot(
                height=500,
                show_label=False,
                container=True,
                bubble_full_width=False
            )

        with gr.Column(scale=1):
            gr.Markdown(
                """
                ### 💡 Tips:
                - Ask questions about any topic
                - Request explanations or tutorials
                - Get help with problems
                - Have casual conversations

                ### 🛠️ Model Info:
                - Model: Microsoft Phi-2
                - Context: Smart memory management
                - Device: """ + device.upper() + """
                """
            )

    with gr.Row():
        user_input = gr.Textbox(
            placeholder="Type your message here...",
            container=False,
            scale=4,
            show_label=False
        )
        send_button = gr.Button("Send 📤", scale=1, variant="primary")

    with gr.Row():
        clear_button = gr.Button("🗑️ Reset Chat", variant="secondary")

    # Shared conversation state: list of (user, ai) message pairs.
    state = gr.State([])

    # Event handlers: button click and Enter both submit; reset clears all
    # three outputs (chat display, state, textbox).
    send_event = send_button.click(
        chatbot,
        inputs=[user_input, state],
        outputs=[chatbot_ui, state, user_input]
    )

    enter_event = user_input.submit(
        chatbot,
        inputs=[user_input, state],
        outputs=[chatbot_ui, state, user_input]
    )

    clear_button.click(
        reset_chat,
        outputs=[chatbot_ui, state, user_input]
    )
 
196
if __name__ == "__main__":
    print("🌟 Starting AeroAI interface...")  # fixed mojibake-garbled emoji
    iface.launch(
        server_name="0.0.0.0",  # bind all interfaces to allow external access
        server_port=7860,
        share=False,  # set to True if you want a public gradio.live link
        show_error=True
    )