salmapm committed on
Commit
38c2f94
·
verified ·
1 Parent(s): 3164fd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -31
app.py CHANGED
@@ -1,28 +1,35 @@
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
- import bitsandbytes as bnb
5
 
6
- # Load the model and tokenizer with 8-bit quantization
7
- tokenizer = AutoTokenizer.from_pretrained("salmapm/llama2_salma")
8
- model = AutoModelForCausalLM.from_pretrained(
9
- "salmapm/llama2_salma",
10
- load_in_8bit=True, # Enable 8-bit quantization
11
- device_map='auto' # Automatically maps model to available devices
12
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- # Ensure the model is on the correct device (GPU if available)
15
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- model.to(device)
 
17
 
18
- def respond(
19
- message,
20
- history: list[tuple[str, str]],
21
- system_message,
22
- max_tokens,
23
- temperature,
24
- top_p,
25
- ):
26
  messages = [{"role": "system", "content": system_message}]
27
 
28
  for val in history:
@@ -33,7 +40,6 @@ def respond(
33
 
34
  messages.append({"role": "user", "content": message})
35
 
36
- # Format the prompt for the model
37
  prompt = f"{system_message}\n" + "\n".join(
38
  [f"{msg['role']}: {msg['content']}" for msg in messages]
39
  )
@@ -48,22 +54,21 @@ def respond(
48
  )
49
 
50
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
51
- yield response
52
 
53
- demo = gr.ChatInterface(
54
- respond,
55
- additional_inputs=[
 
 
 
56
  gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
57
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
58
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
59
- gr.Slider(
60
- minimum=0.1,
61
- maximum=1.0,
62
- value=0.95,
63
- step=0.05,
64
- label="Top-p (nucleus sampling)",
65
- ),
66
  ],
 
67
  )
68
 
69
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
+ from huggingface_hub import login, HfApi
5
 
6
def load_model(token):
    """Authenticate with the Hugging Face Hub and load the chat model.

    Parameters
    ----------
    token : str
        Hugging Face access token with read permission for the model
        repository (required for gated/private repos).

    Returns
    -------
    tuple
        ``(model, tokenizer, device)`` where ``device`` is the
        ``torch.device`` the model's parameters actually live on.

    Raises
    ------
    Exception
        Propagates any authentication or download error from the Hub.

    NOTE(review): this is invoked on every chat request by ``respond``,
    which re-downloads/re-builds the model each time — consider caching
    the loaded model per token.
    """
    # Log in with the user's token so gated repos can be downloaded.
    login(token=token)

    tokenizer = AutoTokenizer.from_pretrained("salmapm/llama2_salma")
    model = AutoModelForCausalLM.from_pretrained(
        "salmapm/llama2_salma",
        load_in_8bit=True,   # 8-bit quantization to fit limited VRAM
        device_map='auto',   # accelerate dispatches layers to devices
    )

    # BUG FIX: do NOT call model.to(device) here.  A model loaded with
    # device_map='auto' is already placed by accelerate, and calling
    # `.to()` on an 8-bit quantized model raises
    # "ValueError: `.to` is not supported for 8-bit ... models" in
    # transformers.  Instead, report the device the weights ended up on.
    device = next(model.parameters()).device

    return model, tokenizer, device
23
+
24
+ def respond(message, history, system_message, max_tokens, temperature, top_p, token):
25
+ if not token:
26
+ return "Please provide a Hugging Face token."
27
 
28
+ try:
29
+ model, tokenizer, device = load_model(token)
30
+ except Exception as e:
31
+ return f"An error occurred: {e}"
32
 
 
 
 
 
 
 
 
 
33
  messages = [{"role": "system", "content": system_message}]
34
 
35
  for val in history:
 
40
 
41
  messages.append({"role": "user", "content": message})
42
 
 
43
  prompt = f"{system_message}\n" + "\n".join(
44
  [f"{msg['role']}: {msg['content']}" for msg in messages]
45
  )
 
54
  )
55
 
56
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
57
+ return response
58
 
59
# Input widgets, in the exact positional order `respond` expects its
# arguments: message, history, system message, generation knobs, token.
_input_widgets = [
    gr.Textbox(label="Message"),
    gr.Textbox(label="History (format: (user_message, assistant_response))", lines=2),
    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    gr.Textbox(label="Hugging Face Token", type="password"),  # masked token field
]

# Plain gr.Interface (not ChatInterface) so that every argument of
# `respond` — including the Hub token — is an explicit input component.
demo = gr.Interface(
    fn=respond,
    inputs=_input_widgets,
    outputs="text",
)
73
 
74
  if __name__ == "__main__":