salmapm commited on
Commit
b0092a1
·
verified ·
1 Parent(s): 5cd3006

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -18
app.py CHANGED
@@ -3,31 +3,26 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
  from huggingface_hub import login
5
 
6
- def load_model(token):
7
- # Log in with the user's token
8
- login(token=token)
9
-
10
- # Define model loading parameters
11
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
- model_kwargs = {}
13
-
14
- if torch.cuda.is_available():
15
- model_kwargs = {
16
- 'load_in_8bit': True, # Enable 8-bit quantization if GPU is available
17
- 'device_map': 'auto', # Automatically maps model to available devices
18
- 'low_cpu_mem_usage': True # Reduce CPU memory usage
19
- }
20
 
21
- try:
 
 
 
 
 
 
 
 
 
 
 
22
  tokenizer = AutoTokenizer.from_pretrained("salmapm/llama2_salma")
23
  model = AutoModelForCausalLM.from_pretrained(
24
  "salmapm/llama2_salma",
25
  **model_kwargs
26
  )
27
  model.to(device)
28
- except Exception as e:
29
- raise RuntimeError(f"Model loading failed: {e}")
30
-
31
  return model, tokenizer, device
32
 
33
  def respond(message, history, system_message, max_tokens, temperature, top_p, token):
 
3
  import torch
4
  from huggingface_hub import login
5
 
6
+ model, tokenizer, device = None, None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ def load_model(token):
9
+ global model, tokenizer, device
10
+ if model is None:
11
+ login(token=token)
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ model_kwargs = {}
14
+ if torch.cuda.is_available():
15
+ model_kwargs = {
16
+ 'load_in_8bit': True,
17
+ 'device_map': 'auto',
18
+ 'low_cpu_mem_usage': True
19
+ }
20
  tokenizer = AutoTokenizer.from_pretrained("salmapm/llama2_salma")
21
  model = AutoModelForCausalLM.from_pretrained(
22
  "salmapm/llama2_salma",
23
  **model_kwargs
24
  )
25
  model.to(device)
 
 
 
26
  return model, tokenizer, device
27
 
28
  def respond(message, history, system_message, max_tokens, temperature, top_p, token):