Solarum Asteridion committed on
Commit
8b977a9
·
verified ·
1 Parent(s): ca1dedb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -150
app.py CHANGED
@@ -1,36 +1,72 @@
1
- import openai
 
 
2
  import datetime
3
  import pytz
4
- import gradio as gr
5
  import logging
6
- import requests
7
-
8
-
9
- grqapi = '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'
10
 
11
  # Configure logging
12
  logging.basicConfig(level=logging.INFO)
13
  logger = logging.getLogger(__name__)
14
 
15
- # Initialize OpenAI client using secrets.toml
16
- client = openai.OpenAI(
17
- api_key=grqapi, # Accessing the API key from secrets.toml
18
- base_url="https://api.groq.com/openai/v1",
19
- )
20
-
21
- def get_available_models():
22
- url = "https://api.groq.com/openai/v1/models"
23
- headers = {
24
- "Authorization": f"Bearer {grqapi}",
25
- "Content-Type": "application/json"
26
- }
27
- response = requests.get(url, headers=headers)
28
- if response.status_code == 200:
29
- models = [model["id"] for model in response.json()["data"]]
30
- return models
31
- else:
32
- logger.error(f"Error fetching models: {response.status_code} - {response.text}")
33
- return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def get_current_local_time(timezone_str='UTC'):
36
  try:
@@ -64,148 +100,107 @@ Make it a challenge to guess if you are an AI.
64
  * Be slightly curious, sometimes ask questions
65
  * Be respectful and polite if the user is polite; be rough if the user is rough
66
  """
 
 
 
67
 
68
- def generate_response(user_message, conversation_history, model_name):
69
  current_time, now = get_current_local_time()
70
  current_date = now
71
-
 
72
  system_message = generate_system_message(current_time, current_date)
73
-
74
- messages = [{"role": "system", "content": system_message}]
75
  for message in conversation_history:
76
- # Remove "metadata" if it exists (to be compatible with Groq API)
77
- messages.append({k: v for k, v in message.items() if k != "metadata"})
78
-
79
- messages.append({"role": "user", "content": user_message})
80
-
81
- try:
82
- response = client.chat.completions.create(
83
- model=model_name,
84
- messages=messages,
85
- temperature=1.25,
86
- max_tokens=500, # Adjust as needed
87
- )
88
- ai_reply = response.choices[0].message.content.strip()
89
- logger.info("------------------------------" + user_message + ", " + ai_reply)
90
- return ai_reply
91
- except Exception as e:
92
- logger.error(f"Error generating response: {e}")
93
- return "Sorry, I encountered an error while processing your request."
94
-
95
- def chatbot_interface(user_message, history, model_name):
96
  if history is None:
97
  history = []
98
-
99
- ai_response = generate_response(user_message, history, model_name)
100
  history.append({"role": "user", "content": user_message})
101
  history.append({"role": "assistant", "content": ai_response})
102
- logger.info("Chat history: %s", history) # Corrected logging
103
  return history, history
104
 
105
  # Define Gradio Interface
106
  with gr.Blocks(css="""
107
- /* Import Raleway font from Google Fonts */
108
  @import url('https://fonts.googleapis.com/css2?family=Raleway:wght@400;600&display=swap');
109
 
110
  body, .gradio-container {
111
- font-family: 'Raleway', sans-serif;
112
- background-color: #f5f5f5;
113
- padding: 20px;
114
  }
115
  #chatbot {
116
- height: 600px; /* Increased height for a bigger chat box */
117
- overflow-y: auto;
118
- background-color: #ffffff;
119
- border-radius: 10px;
120
- padding: 10px;
121
- font-size: 16px;
122
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
123
- }
124
- #textbox {
125
- width: 100%;
126
- border-radius: 25px;
127
- border: 1px solid #ccc;
128
- outline: none;
129
- font-size: 16px;
130
- padding: 10px 20px;
131
- box-sizing: border-box;
132
- }
133
- #send-button {
134
- background-color: #007BFF;
135
- color: white;
136
- border: none;
137
- cursor: pointer;
138
- font-size: 20px;
139
- }
140
- #send-button:hover {
141
- background-color: #0056b3;
142
- }
143
- .message {
144
- margin-bottom: 10px;
145
- }
146
- /* Scrollbar Styling */
147
- #chatbot::-webkit-scrollbar {
148
- width: 8px;
149
- }
150
- #chatbot::-webkit-scrollbar-track {
151
- background: #f1f1f1;
152
- }
153
- #chatbot::-webkit-scrollbar-thumb {
154
- background: #888;
155
- border-radius: 4px;
156
- }
157
- #chatbot::-webkit-scrollbar-thumb:hover {
158
- background: #555;
159
- }
160
- /* Responsive Design */
161
- @media (max-width: 600px) {
162
- #send-button {
163
- width: 40px;
164
- height: 40px;
165
- font-size: 18px;
166
- }
167
- #textbox {
168
- padding: 8px 16px;
169
- }
170
  }
171
  """) as demo:
172
- gr.Markdown("<h1 style='text-align: center; color: #007BFF;'>🤖 Human-like Chatbot 🤖</h1>")
173
-
174
- available_models = get_available_models()
175
- if not available_models:
176
- gr.Markdown("**Error: Could not fetch available models from the API.**")
177
- else:
178
- with gr.Row():
179
- model_dropdown = gr.Dropdown(choices=available_models, label="Select Model", value=available_models[0])
180
- with gr.Row():
181
- with gr.Column(scale=1):
182
- chatbot = gr.Chatbot(label="Chatbot", elem_id="chatbot", type="messages")
183
- with gr.Column(scale=1):
184
- with gr.Row():
185
- msg = gr.Textbox(
186
- placeholder="Type your message here...",
187
- show_label=False,
188
- container=False,
189
- elem_id="textbox"
190
- )
191
- send = gr.Button("➤", elem_id="send-button")
192
-
193
- def update_chat(user_message, history, model_name):
194
- if user_message.strip() == "":
195
- return history, history # Do not process empty messages
196
- history, updated_history = chatbot_interface(user_message, history, model_name)
197
- return history, updated_history, "" # Clear textbox
198
-
199
- send.click(
200
- update_chat,
201
- inputs=[msg, chatbot, model_dropdown],
202
- outputs=[chatbot, chatbot, msg] # Added msg to outputs
203
- )
204
-
205
- msg.submit(
206
- update_chat,
207
- inputs=[msg, chatbot, model_dropdown],
208
- outputs=[chatbot, chatbot, msg] # Added msg to outputs
209
- )
210
-
211
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import gradio as gr
4
  import datetime
5
  import pytz
 
6
  import logging
7
+ import gc
 
 
 
8
 
9
  # Configure logging
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
class LocalLLMHandler:
    """Hold a locally loaded causal language model and its tokenizer.

    ``model`` and ``tokenizer`` are ``None`` until ``load_model`` succeeds,
    and are reset to ``None`` whenever a load attempt fails, so
    ``handler.model is None`` is always a safe "not loaded" check.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None

    def load_model(self, model_name="nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"):
        """Load ``model_name`` on the CPU, replacing any model already held.

        Args:
            model_name: Hugging Face model identifier to load.

        Returns:
            True when both tokenizer and model were loaded, False otherwise
            (the error is logged and the handler is left unloaded).
        """
        try:
            # Release any existing model before allocating the next one.
            # The attributes are set to None rather than deleted so that they
            # still exist if loading fails part-way through.
            if self.model is not None:
                self.model = None
                self.tokenizer = None
                gc.collect()
                torch.cuda.empty_cache()

            # CPU-specific configuration.
            model_kwargs = {
                "device_map": "cpu",
                # bfloat16 halves the memory footprint compared to float32.
                "torch_dtype": torch.bfloat16,
                "low_cpu_mem_usage": True,
            }

            logger.info("Loading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(model_name)

            logger.info("Loading model...")
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                **model_kwargs
            )

            # Publish both objects only once both loads have succeeded.
            self.tokenizer = tokenizer
            self.model = model

            logger.info("Model loaded successfully")
            return True
        except Exception as e:
            # Never leave a half-loaded handler behind.
            self.model = None
            self.tokenizer = None
            logger.error(f"Error loading model: {e}")
            return False

    def generate_response(self, prompt, max_length=500):
        """Generate a completion for ``prompt`` and return only the new text.

        Args:
            prompt: Full text prompt to continue.
            max_length: Maximum number of tokens to generate, not counting
                the tokens of the prompt itself.

        Returns:
            The decoded completion with the prompt removed and surrounding
            whitespace stripped, or a fixed apology string if anything fails
            (including the model not being loaded); the error is logged.
        """
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt")
            input_ids = inputs["input_ids"]
            prompt_token_count = len(input_ids[0])

            outputs = self.model.generate(
                input_ids,
                # Pass the mask explicitly; pad and EOS share one token id
                # here, so it cannot be inferred reliably from the ids.
                attention_mask=inputs.get("attention_mask"),
                # Budget applies to generated tokens only, so a long prompt
                # cannot consume the whole allowance.
                max_new_tokens=max_length,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

            # The output sequence starts with the prompt tokens; keep only
            # what the model added after them.
            completion_ids = outputs[0][prompt_token_count:]
            response = self.tokenizer.decode(completion_ids, skip_special_tokens=True)
            return response.strip()
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return "Sorry, I encountered an error while processing your request."
70
 
71
  def get_current_local_time(timezone_str='UTC'):
72
  try:
 
100
  * Be slightly curious, sometimes ask questions
101
  * Be respectful and polite if the user is polite; be rough if the user is rough
102
  """
103
+
104
+ # Initialize the model handler
105
+ llm_handler = LocalLLMHandler()
106
 
107
def generate_response(user_message, conversation_history):
    """Return the assistant's reply to ``user_message``.

    The system message, the prior turns and the new user message are
    flattened into one plain-text transcript prompt ending in
    ``Assistant:``, which is handed to ``llm_handler.generate_response``.

    Args:
        user_message: The text the user just sent.
        conversation_history: Earlier turns as dicts with ``"role"`` and
            ``"content"`` keys, or ``None``/empty for a new conversation.
            Any role other than ``"user"`` is rendered as an assistant turn.

    Returns:
        The reply text with any echoed prompt and any trailing invented
        ``User:`` turn removed.
    """
    current_time, now = get_current_local_time()
    current_date = now

    # Construct the complete prompt from conversation history
    system_message = generate_system_message(current_time, current_date)
    turns = []
    for message in conversation_history or []:
        speaker = "User" if message["role"] == "user" else "Assistant"
        turns.append(f"{speaker}: {message['content']}\n")
    transcript = "".join(turns)
    prompt = f"{system_message}\n\n{transcript}User: {user_message}\nAssistant:"

    # Generate response
    ai_reply = llm_handler.generate_response(prompt)

    # A causal LM may hand back prompt + continuation; drop the echo when the
    # reply begins with the exact prompt text (best effort).
    if ai_reply.startswith(prompt):
        ai_reply = ai_reply[len(prompt):]
    # With a plain transcript prompt the model can keep writing the next
    # "User:" turn itself; keep only the assistant's own turn.
    ai_reply = ai_reply.split("\nUser:", 1)[0].strip()

    logger.info(f"User: {user_message}\nAssistant: {ai_reply}")
    return ai_reply
127
+
128
def chatbot_interface(user_message, history):
    """Run one chat turn and record it in ``history``.

    The reply is generated from the history as it stood before this turn;
    the user message and the reply are then appended as role/content dicts.
    A ``None`` history starts a fresh list. The same list is returned twice.
    """
    history = [] if history is None else history

    reply = generate_response(user_message, history)
    history.extend(
        [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": reply},
        ]
    )
    return history, history
136
 
137
  # Define Gradio Interface
138
  with gr.Blocks(css="""
 
139
  @import url('https://fonts.googleapis.com/css2?family=Raleway:wght@400;600&display=swap');
140
 
141
  body, .gradio-container {
142
+ font-family: 'Raleway', sans-serif;
143
+ background-color: #f5f5f5;
144
+ padding: 20px;
145
  }
146
  #chatbot {
147
+ height: 600px;
148
+ overflow-y: auto;
149
+ background-color: #ffffff;
150
+ border-radius: 10px;
151
+ padding: 10px;
152
+ font-size: 16px;
153
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  }
155
  """) as demo:
156
+ gr.Markdown("<h1 style='text-align: center; color: #007BFF;'>🤖 Local Llama Chatbot 🤖</h1>")
157
+
158
+ # Load model button
159
+ with gr.Row():
160
+ load_button = gr.Button("Load Model")
161
+ model_status = gr.Textbox(label="Model Status", value="Model not loaded", interactive=False)
162
+
163
+ with gr.Row():
164
+ with gr.Column(scale=1):
165
+ chatbot = gr.Chatbot(label="Chatbot", elem_id="chatbot")
166
+ with gr.Column(scale=1):
167
+ with gr.Row():
168
+ msg = gr.Textbox(
169
+ placeholder="Type your message here...",
170
+ show_label=False,
171
+ container=False,
172
+ elem_id="textbox"
173
+ )
174
+ send = gr.Button("➤", elem_id="send-button")
175
+
176
+ def load_model_click():
177
+ success = llm_handler.load_model()
178
+ return "Model loaded successfully" if success else "Error loading model"
179
+
180
+ def update_chat(user_message, history):
181
+ if user_message.strip() == "":
182
+ return history, history
183
+ if llm_handler.model is None:
184
+ return history + [("Error", "Please load the model first")], history
185
+ history, updated_history = chatbot_interface(user_message, history)
186
+ return history, updated_history, ""
187
+
188
+ load_button.click(
189
+ load_model_click,
190
+ outputs=[model_status]
191
+ )
192
+
193
+ send.click(
194
+ update_chat,
195
+ inputs=[msg, chatbot],
196
+ outputs=[chatbot, chatbot, msg]
197
+ )
198
+
199
+ msg.submit(
200
+ update_chat,
201
+ inputs=[msg, chatbot],
202
+ outputs=[chatbot, chatbot, msg]
203
+ )
204
+
205
+ if __name__ == "__main__":
206
+ demo.launch()