mlabonne committed on
Commit
3f2b118
·
verified ·
1 Parent(s): 5938174

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -46
app.py CHANGED
@@ -1,53 +1,105 @@
1
  import os
2
  import json
 
 
 
 
 
3
  import gradio as gr
4
- from llama_cpp import Llama
5
-
6
- # Get environment variables
7
- model_id = os.getenv('MODEL')
8
- quant = os.getenv('QUANT')
9
- chat_template = os.getenv('CHAT_TEMPLATE')
10
-
11
- # Interface variables
12
- model_name = model_id.split('/')[1].split('-GGUF')[0]
13
- title = f"👑 {model_name}"
14
- description = f"Chat with <a href=\"https://huggingface.co/{model_id}\">{model_name}</a> in GGUF format ({quant})!"
15
-
16
- # Initialize the LLM
17
- llm = Llama(model_path="model.gguf",
18
- n_ctx=32768,
19
- n_threads=2,
20
- chat_format=chat_template)
21
-
22
- # Function for streaming chat completions
23
- def chat_stream_completion(message, history, system_prompt):
24
- messages_prompts = [{"role": "system", "content": system_prompt}]
25
- for human, assistant in history:
26
- messages_prompts.append({"role": "user", "content": human})
27
- messages_prompts.append({"role": "assistant", "content": assistant})
28
- messages_prompts.append({"role": "user", "content": message})
29
-
30
- response = llm.create_chat_completion(
31
- messages=messages_prompts,
32
- stream=True,
33
- stop=["[INST]"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  )
35
- message_repl = ""
36
- for chunk in response:
37
- if len(chunk['choices'][0]["delta"]) != 0 and "content" in chunk['choices'][0]["delta"]:
38
- message_repl = message_repl + chunk['choices'][0]["delta"]["content"]
39
- yield message_repl
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- # Gradio chat interface
42
  gr.ChatInterface(
43
- fn=chat_stream_completion,
44
- title=title,
45
- description=description,
46
- additional_inputs=[gr.Textbox("You are helpful assistant.")],
47
- additional_inputs_accordion="📝 System prompt",
48
  examples=[
49
- ["What is a Large Language Model?"],
50
- ["What's 9+2-1?"],
51
- ["Write Python code to print the Fibonacci sequence"]
52
- ]
53
- ).queue().launch(server_name="0.0.0.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import json
3
+ import subprocess
4
+ from threading import Thread
5
+
6
+ import torch
7
+ import spaces
8
  import gradio as gr
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
10
+
11
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
12
+
13
+ MODEL_ID = os.environ.get("MODEL_ID")
14
+ CHAT_TEMPLATE = os.environ.get("CHAT_TEMPLATE")
15
+ MODEL_NAME = MODEL_ID.split("/")[-1]
16
+ CONTEXT_LENGTH = int(os.environ.get("CONTEXT_LENGTH"))
17
+ COLOR = os.environ.get("COLOR")
18
+ EMOJI = os.environ.get("EMOJI")
19
+ DESCRIPTION = os.environ.get("DESCRIPTION")
20
+
21
+
22
+ @spaces.GPU()
23
+ def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
24
+ # Format history with a given chat template
25
+ if CHAT_TEMPLATE == "ChatML":
26
+ stop_tokens = ["<|endoftext|>", "<|im_end|>"]
27
+ instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
28
+ for human, assistant in history:
29
+ instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant
30
+ instruction += '\n<|im_start|>user\n' + message + '\n<|im_end|>\n<|im_start|>assistant\n'
31
+ elif CHAT_TEMPLATE == "Mistral Instruct":
32
+ stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
33
+ instruction = '<s>[INST] ' + system_prompt
34
+ for human, assistant in history:
35
+ instruction += human + ' [/INST] ' + assistant + '</s>[INST]'
36
+ instruction += ' ' + message + ' [/INST]'
37
+ else:
38
+ raise Exception("Incorrect chat template, select 'ChatML' or 'Mistral Instruct'")
39
+ print(instruction)
40
+
41
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
42
+ enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
43
+ input_ids, attention_mask = enc.input_ids, enc.attention_mask
44
+
45
+ if input_ids.shape[1] > CONTEXT_LENGTH:
46
+ input_ids = input_ids[:, -CONTEXT_LENGTH:]
47
+
48
+ generate_kwargs = dict(
49
+ {"input_ids": input_ids.to(device), "attention_mask": attention_mask.to(device)},
50
+ streamer=streamer,
51
+ do_sample=True,
52
+ temperature=temperature,
53
+ max_new_tokens=max_new_tokens,
54
+ top_k=top_k,
55
+ repetition_penalty=repetition_penalty,
56
+ top_p=top_p
57
  )
58
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
59
+ t.start()
60
+ outputs = []
61
+ for new_token in streamer:
62
+ outputs.append(new_token)
63
+ if new_token in stop_tokens:
64
+ break
65
+ yield "".join(outputs)
66
+
67
+
68
+ # Load model
69
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
70
+ quantization_config = BitsAndBytesConfig(
71
+ load_in_4bit=True,
72
+ bnb_4bit_compute_dtype=torch.bfloat16
73
+ )
74
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
75
+ model = AutoModelForCausalLM.from_pretrained(
76
+ MODEL_ID,
77
+ device_map="auto",
78
+ quantization_config=quantization_config,
79
+ attn_implementation="flash_attention_2",
80
+ )
81
 
82
+ # Create Gradio interface
83
  gr.ChatInterface(
84
+ predict,
85
+ title=EMOJI + " " + MODEL_NAME,
86
+ description=DESCRIPTION,
 
 
87
  examples=[
88
+ ["Can you solve the equation 2x + 3 = 11 for x?"],
89
+ ["Write an epic poem about Ancient Rome."],
90
+ ["Who was the first person to walk on the Moon?"],
91
+ ["Use a list comprehension to create a list of squares for numbers from 1 to 10."],
92
+ ["Recommend some popular science fiction books."],
93
+ ["Can you write a short story about a time-traveling detective?"]
94
+ ],
95
+ additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
96
+ additional_inputs=[
97
+ gr.Textbox("Perform the task to the best of your ability.", label="System prompt"),
98
+ gr.Slider(0, 1, 0.8, label="Temperature"),
99
+ gr.Slider(128, 4096, 1024, label="Max new tokens"),
100
+ gr.Slider(1, 80, 40, label="Top K sampling"),
101
+ gr.Slider(0, 2, 1.1, label="Repetition penalty"),
102
+ gr.Slider(0, 1, 0.95, label="Top P sampling"),
103
+ ],
104
+ theme=gr.themes.Soft(primary_hue=COLOR),
105
+ ).queue().launch()