salmapm committed on
Commit
5cd3006
·
verified ·
1 Parent(s): 6faa75f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -11
app.py CHANGED
@@ -1,23 +1,32 @@
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
- from huggingface_hub import login, HfApi
5
 
6
  def load_model(token):
7
  # Log in with the user's token
8
  login(token=token)
9
 
10
- # Load the model and tokenizer
11
- tokenizer = AutoTokenizer.from_pretrained("salmapm/llama2_salma")
12
- model = AutoModelForCausalLM.from_pretrained(
13
- "salmapm/llama2_salma",
14
- load_in_8bit=True, # Enable 8-bit quantization
15
- device_map='auto' # Automatically maps model to available devices
16
- )
17
-
18
- # Ensure the model is on the correct device (GPU if available)
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
- model.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  return model, tokenizer, device
23
 
 
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
+ from huggingface_hub import login
5
 
6
  def load_model(token):
7
  # Log in with the user's token
8
  login(token=token)
9
 
10
+ # Define model loading parameters
 
 
 
 
 
 
 
 
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ model_kwargs = {}
13
+
14
+ if torch.cuda.is_available():
15
+ model_kwargs = {
16
+ 'load_in_8bit': True, # Enable 8-bit quantization if GPU is available
17
+ 'device_map': 'auto', # Automatically maps model to available devices
18
+ 'low_cpu_mem_usage': True # Reduce CPU memory usage
19
+ }
20
+
21
+ try:
22
+ tokenizer = AutoTokenizer.from_pretrained("salmapm/llama2_salma")
23
+ model = AutoModelForCausalLM.from_pretrained(
24
+ "salmapm/llama2_salma",
25
+ **model_kwargs
26
+ )
27
+ model.to(device)
28
+ except Exception as e:
29
+ raise RuntimeError(f"Model loading failed: {e}")
30
 
31
  return model, tokenizer, device
32