Solarum Asteridion committed
Commit fce480e · verified · 1 Parent(s): 89b8edb

Update app.py

Files changed (1):
  1. app.py +16 -26
app.py CHANGED
@@ -7,7 +7,7 @@ import logging
 import gc
 import psutil
 import os
-from huggingface_hub import login
+from huggingface_hub import login, hf_api

 class MemoryTracker:
     @staticmethod
@@ -16,14 +16,17 @@ class MemoryTracker:
         memory_gb = process.memory_info().rss / 1024 / 1024 / 1024
         return f"{memory_gb:.2f} GB"

-# Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

 def setup_huggingface_auth():
-    """Set up Hugging Face authentication"""
-    # First try to get token from environment variable
-    token = "hf_" + 'FsIJpNepbYgvSdnNhTFtifFudneNNDpUj' + "k"
+    token = os.environ.get("HF_TOKEN")
+    if token is None:
+        token = hf_api.HfFolder.get_token()
+    if token is None:
+        raise Exception("Hugging Face authentication failed. Please set your token.")
+    login(token)
+    return True

 class LocalLLMHandler:
     def __init__(self):
@@ -32,46 +35,36 @@ class LocalLLMHandler:
         self.memory_tracker = MemoryTracker()

     def load_model(self, model_name="meta-llama/Llama-3.1-8B-Instruct"):
-        """Load model with optimizations for 16GB RAM"""
         try:
-            # Ensure we're authenticated
             if not setup_huggingface_auth():
                 raise Exception("Hugging Face authentication failed. Please set your token first.")
-            # Clean up any existing model
             if self.model is not None:
                 del self.model
                 del self.tokenizer
                 torch.cuda.empty_cache()
                 gc.collect()

-            # CPU-specific configurations
             model_kwargs = {
                 "device_map": "cpu",
-                "torch_dtype": torch.bfloat16,  # Use bfloat16 for better CPU performance
+                "torch_dtype": torch.bfloat16,
                 "low_cpu_mem_usage": True,
             }

-            logger.info("Loading tokenizer...")
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)

-            logger.info("Loading model...")
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_name,
                 **model_kwargs
             )

-            logger.info("Model loaded successfully")
             return True
         except Exception as e:
             logger.error(f"Error loading model: {e}")
-            return False
+            return f"Error loading model: {e}"

     def generate_response(self, prompt, max_length=500):
-        """Generate response from the local model"""
         try:
             inputs = self.tokenizer(prompt, return_tensors="pt")
-
-            # Generate with CPU-friendly parameters
             outputs = self.model.generate(
                 inputs["input_ids"],
                 max_length=max_length,
@@ -80,12 +73,11 @@
                 do_sample=True,
                 pad_token_id=self.tokenizer.eos_token_id
             )
-
             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
             return response
         except Exception as e:
             logger.error(f"Error generating response: {e}")
-            return "Sorry, I encountered an error while processing your request."
+            return f"Error generating response: {str(e)}"

 def get_current_local_time(timezone_str='UTC'):
     try:
@@ -120,14 +112,12 @@ Make it a challenge to guess if you are an AI.
 * Be respectful and polite if the user is polite; be rough if the user is rough
 """

-# Initialize the model handler
 llm_handler = LocalLLMHandler()

 def generate_response(user_message, conversation_history):
     current_time, now = get_current_local_time()
     current_date = now

-    # Construct the complete prompt from conversation history
     system_message = generate_system_message(current_time, current_date)
     prompt = system_message + "\n\n"

@@ -139,7 +129,6 @@ def generate_response(user_message, conversation_history):

     prompt += f"User: {user_message}\nAssistant:"

-    # Generate response
     ai_reply = llm_handler.generate_response(prompt)
     logger.info(f"User: {user_message}\nAssistant: {ai_reply}")
     return ai_reply
@@ -153,7 +142,6 @@ def chatbot_interface(user_message, history):
     history.append({"role": "assistant", "content": ai_response})
     return history, history

-# Define Gradio Interface
 with gr.Blocks(css="""
 @import url('https://fonts.googleapis.com/css2?family=Raleway:wght@400;600&display=swap');

@@ -174,7 +162,6 @@ body, .gradio-container {
 """) as demo:
     gr.Markdown("<h1 style='text-align: center; color: #007BFF;'>🤖 Local Llama Chatbot 🤖</h1>")

-    # Load model button
     with gr.Row():
         load_button = gr.Button("Load Model")
         model_status = gr.Textbox(label="Model Status", value="Model not loaded", interactive=False)
@@ -193,8 +180,11 @@ body, .gradio-container {
         send = gr.Button("➤", elem_id="send-button")

     def load_model_click():
-        success = llm_handler.load_model()
-        return "Model loaded successfully" if success else "Error loading model"
+        result = llm_handler.load_model()
+        if isinstance(result, str):
+            return result
+        else:
+            return "Model loaded successfully" if result else "Error loading model"

     def update_chat(user_message, history):
         if user_message.strip() == "":
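
The substantive fix in this commit is the removal of a hardcoded access token from setup_huggingface_auth() in favor of environment-based lookup. A minimal sketch of that flow run standalone, assuming huggingface_hub is installed and HF_TOKEN is set in the environment (for a Space, as a repository secret); whoami() is a standard huggingface_hub call, used here only to confirm the token resolves to an account:

import os
from huggingface_hub import login, whoami

# Mirror the new auth flow: take HF_TOKEN from the environment rather
# than embedding it in source, where it would leak through commits.
token = os.environ.get("HF_TOKEN")
if token is None:
    raise SystemExit("Set HF_TOKEN before launching app.py")
login(token)
print(whoami()["name"])  # prints the account name the token belongs to

Launching then becomes `HF_TOKEN=<your token> python app.py`, with no secret left in the repository.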
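
The commit also changes load_model()'s contract: it now returns True on success and the error message string on failure, and load_model_click() distinguishes the two with isinstance() so the concrete error reaches the status textbox. A small self-contained illustration of that pattern; StubHandler is a hypothetical stand-in for LocalLLMHandler:

# Hypothetical stub mirroring the new load_model() contract:
# True on success, a descriptive error string on failure.
class StubHandler:
    def __init__(self, fail=False):
        self.fail = fail

    def load_model(self):
        if self.fail:
            return "Error loading model: out of memory"
        return True

def load_model_click(handler):
    result = handler.load_model()
    if isinstance(result, str):
        return result  # forward the concrete error message to the UI
    return "Model loaded successfully" if result else "Error loading model"

print(load_model_click(StubHandler()))           # Model loaded successfully
print(load_model_click(StubHandler(fail=True)))  # Error loading model: out of memory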