turkfork commited on
Commit
61f8ea3
·
verified ·
1 Parent(s): 0761a71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -22
app.py CHANGED
@@ -5,7 +5,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
5
  model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
6
  tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
7
 
8
- # Set pad_token as eos_token
9
  tokenizer.pad_token = tokenizer.eos_token
10
 
11
  # Load training protocol from file
@@ -16,39 +15,66 @@ except FileNotFoundError:
16
  training_protocol = "You are AeroAI, a helpful, friendly, and slightly humorous educational assistant."
17
  print("⚠ training-protocol.aero not found, using default protocol.")
18
 
19
- # Chatbot function
20
- def chatbot(input_text):
21
- # Combine protocol with user input
22
- prompt = f"{training_protocol}\n\nUser: {input_text}\nAeroAI:"
 
 
 
 
 
 
 
 
 
23
 
 
 
 
 
 
 
 
 
 
24
  inputs = tokenizer(
25
- prompt,
26
  return_tensors="pt",
27
  padding=True,
28
  truncation=True
29
  )
 
30
  outputs = model.generate(
31
  **inputs,
32
- max_length=2000000000000,
33
  do_sample=True,
34
  pad_token_id=tokenizer.pad_token_id
35
  )
 
36
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
37
-
38
- # Optional: cut out the protocol part from the output
39
  if "AeroAI:" in response:
40
  response = response.split("AeroAI:")[-1].strip()
41
-
42
- return response
43
-
44
- # Gradio Interface
45
- iface = gr.Interface(
46
- fn=chatbot,
47
- inputs="text",
48
- outputs="text",
49
- title="AeroAI (Phi-2)",
50
- description="By Blacklink Education in collaboration with Microsoft."
51
- )
52
-
53
- # Launch the interface
 
 
 
 
 
 
 
 
 
54
  iface.launch()
 
5
  model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
6
  tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
7
 
 
8
  tokenizer.pad_token = tokenizer.eos_token
9
 
10
  # Load training protocol from file
 
15
  training_protocol = "You are AeroAI, a helpful, friendly, and slightly humorous educational assistant."
16
  print("⚠ training-protocol.aero not found, using default protocol.")
17
 
18
+ # Thinking messages (rotate through these while generating)
19
+ thinking_messages = [
20
+ "🤔 Thinking deeply about your question...",
21
+ "📚 Flipping through my mental textbooks...",
22
+ "🧮 Running some quick calculations...",
23
+ "💡 Connecting the dots...",
24
+ "🔍 Double-checking my facts..."
25
+ ]
26
+
27
+ # Chatbot function with memory
28
+ def chatbot(user_input, history):
29
+ if history is None:
30
+ history = []
31
 
32
+ # Append user's message to history
33
+ history.append(("User", user_input))
34
+
35
+ # Build the full conversation prompt
36
+ conversation = training_protocol + "\n\n"
37
+ for speaker, text in history:
38
+ conversation += f"{speaker}: {text}\n"
39
+ conversation += "AeroAI:"
40
+
41
  inputs = tokenizer(
42
+ conversation,
43
  return_tensors="pt",
44
  padding=True,
45
  truncation=True
46
  )
47
+
48
  outputs = model.generate(
49
  **inputs,
50
+ max_new_tokens=200, # safer than huge max_length
51
  do_sample=True,
52
  pad_token_id=tokenizer.pad_token_id
53
  )
54
+
55
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
56
  if "AeroAI:" in response:
57
  response = response.split("AeroAI:")[-1].strip()
58
+
59
+ history.append(("AeroAI", response))
60
+ return history, history
61
+
62
+ # Reset chat
63
+ def reset_chat():
64
+ return [], []
65
+
66
+ # Build Gradio interface
67
+ with gr.Blocks() as iface:
68
+ gr.Markdown("# AeroAI (Phi-2) — By Blacklink Education")
69
+ chatbot_ui = gr.Chatbot()
70
+ user_input = gr.Textbox(placeholder="Type your message...")
71
+ send_button = gr.Button("Send")
72
+ clear_button = gr.Button("Reset Chat")
73
+
74
+ state = gr.State([])
75
+
76
+ send_button.click(chatbot, inputs=[user_input, state], outputs=[chatbot_ui, state])
77
+ user_input.submit(chatbot, inputs=[user_input, state], outputs=[chatbot_ui, state])
78
+ clear_button.click(reset_chat, outputs=[chatbot_ui, state])
79
+
80
  iface.launch()