arjunanand13 commited on
Commit
40c48ea
·
verified ·
1 Parent(s): eda06cd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -0
app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import spaces
4
+ import torch
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
6
+ import gradio as gr
7
+ from threading import Thread
8
+
9
# Known-good model ids for this Space; index 0 is the default.
MODEL_LIST = ["meta-llama/Meta-Llama-3.1-8B-Instruct"]
# HF access token for gated repos (meta-llama/* requires one).
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Fall back to the first known model when MODEL_ID is unset; the original
# passed None straight to from_pretrained, which fails with a confusing error.
MODEL = os.environ.get("MODEL_ID") or MODEL_LIST[0]

TITLE = "<h1><center>Meta-Llama3.1-8B (4-bit, 15GB GPU)</center></h1>"

PLACEHOLDER = """
<center>
<p>Hi! How can I help you today?</p>
</center>
"""

CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
"""

device = "cuda" # for GPU usage or "cpu" for CPU usage

# 4-bit NF4 quantization with double quantization; fp16 compute keeps
# generation fast while fitting the 8B model in ~15 GB of GPU memory.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)
42
+
43
# Load tokenizer and 4-bit quantized model at import time (module-level,
# so the Space loads the weights once). HF_TOKEN is forwarded so gated
# repos such as meta-llama/* can be downloaded; previously the token was
# read from the environment but never passed to from_pretrained.
tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    token=HF_TOKEN,
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)
51
+
52
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float = 0.8,
    max_new_tokens: int = 1024,
    top_p: float = 1.0,
    top_k: int = 20,
    penalty: float = 1.2,
):
    """Stream a chat completion for *message* given the prior *history*.

    The parameter order after (message, history) must match the order of
    ``additional_inputs`` in the Gradio ChatInterface, which passes them
    positionally: system prompt, temperature, max new tokens, top_p, top_k,
    repetition penalty. The original signature accepted only
    (system_prompt, max_new_tokens) yet the body referenced ``temperature``,
    ``top_p``, ``top_k`` and ``penalty`` — a guaranteed NameError/TypeError
    at generation time. Defaults mirror the slider defaults in the UI.

    Yields the accumulated response text so Gradio renders it incrementally.
    """
    print(f'message: {message}')
    print(f'history: {history}')

    # Rebuild the whole conversation in chat-template message format.
    conversation = [
        {"role": "system", "content": system_prompt}
    ]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])

    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)

    # Streamer feeds decoded tokens back to this generator as they arrive.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=False if temperature == 0 else True,  # greedy when temperature is 0
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
        eos_token_id=[128001,128008,128009],  # Llama-3.1 end-of-turn token ids
        streamer=streamer,
    )

    # generate() runs in a worker thread so we can consume the streamer here.
    # NOTE: the original wrapped thread.start() in torch.no_grad(), which has
    # no effect on the worker thread (grad mode is thread-local) and generate()
    # already runs without gradients internally — so the wrapper was dropped.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
96
+
97
# Chat display area; render=False-style deferral is not needed here because
# ChatInterface accepts a pre-built Chatbot instance.
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)

# Top-level UI: title, duplicate button, and the chat interface.
# NOTE(review): the six additional_inputs below are passed to stream_chat
# positionally, in this exact order, after (message, history) — keep the
# stream_chat signature in sync with this list.
with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.HTML(TITLE)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        # Generation controls live in a collapsed accordion.
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Textbox(
                value="You are a helpful assistant",
                label="System Prompt",
                render=False,
            ),
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=8192,
                step=1,
                value=1024,
                label="Max new tokens",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=20,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=1.2,
                label="Repetition penalty",
                render=False,
            ),
        ],
        examples=[
            ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
            ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
            ["Tell me a random fun fact about the Roman Empire."],
            ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
        ],
        cache_examples=False,
    )
162
+
163
if __name__ == "__main__":
    # Start the Gradio server (blocking) when run as a script.
    demo.launch()