Artples committed on
Commit 5e419a8 · verified · 1 Parent(s): 0c56476

Update app.py

Files changed (1)
  1. app.py +55 -20
app.py CHANGED
@@ -13,25 +13,37 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 DESCRIPTION = """\
 # L-MChat
-This Space demonstrates [L-MChat](https://huggingface.co/collections/Artples/l-mchat-663265a8351231c428318a8f) by L-AI. <br> To select the Model that you want to use please go to the Adavanced Inputs, the Quality-Model (L-MChat-7b) is activated by default.
+
+This Space demonstrates **L-MChat**, a pair of chat-optimized language models:
+
+- **Fast-Model**: `Artples/L-MChat-Small`
+- **Quality-Model**: `Artples/L-MChat-7b`
+
+By default the **Quality-Model** is selected. You can switch to the Fast-Model if you want
+lower latency at the cost of quality.
+
+Use the *System prompt* field to steer the assistant’s behavior (for example:
+“Act as a helpful programming tutor”). The sliders allow you to configure the
+generation parameters.
 """
 
 if not torch.cuda.is_available():
-    DESCRIPTION += "\n<p>Running on CPU! This demo does not work on CPU.</p>"
+    DESCRIPTION += "\n\n<p>Running on CPU! This demo does not work on CPU.</p>"
 
 model_dict = {
     "Fast-Model": "Artples/L-MChat-Small",
-    "Quality-Model": "Artples/L-MChat-7b"
+    "Quality-Model": "Artples/L-MChat-7b",
 }
 
+
 @spaces.GPU(enable_queue=True, duration=90)
 def generate(
     message: str,
     chat_history: list[tuple[str, str]],
     system_prompt: str,
     model_choice: str,
-    max_new_tokens: int = 1024,
-    temperature: float = 0.1,
+    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+    temperature: float = 0.6,
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
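The new signature pulls its defaults from module-level constants (`DEFAULT_MAX_NEW_TOKENS` here; `MAX_MAX_NEW_TOKENS` appears in the next hunk) defined above the visible diff. The hunks also never show how `model_choice` is resolved into the `model_id` and `model` used inside `generate`, so the following is only a sketch of the lookup that `model_dict` implies; the `load_model` helper, its caching, and the `from_pretrained` arguments are assumptions, not code from this commit:

```python
# Hypothetical helper: the commit shows model_dict but not the loading code,
# so treat this lookup (and the caching) as an illustrative assumption.
from functools import lru_cache

from transformers import AutoModelForCausalLM, AutoTokenizer

model_dict = {
    "Fast-Model": "Artples/L-MChat-Small",
    "Quality-Model": "Artples/L-MChat-7b",
}


@lru_cache(maxsize=2)
def load_model(model_choice: str):
    """Resolve a UI label such as 'Quality-Model' to a model and tokenizer."""
    model_id = model_dict[model_choice]  # KeyError on unknown labels is acceptable here
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
    return model, tokenizer
```

Caching per checkpoint would avoid reloading weights on every request; the diff itself re-creates the tokenizer inside `generate`, so treat this as one reasonable alternative rather than the app's actual behavior.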
@@ -41,45 +53,68 @@ def generate(
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     tokenizer.use_default_system_prompt = False
 
-    conversation = []
+    conversation: list[dict] = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
+
     for user, assistant in chat_history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+        conversation.append({"role": "user", "content": user})
+        if assistant is not None:
+            conversation.append({"role": "assistant", "content": assistant})
+
     conversation.append({"role": "user", "content": message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
+    input_ids = tokenizer.apply_chat_template(
+        conversation,
+        return_tensors="pt",
+        add_generation_prompt=True,
+    )
+
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+        gr.Warning(
+            f"Trimmed input from conversation as it was longer than "
+            f"{MAX_INPUT_TOKEN_LENGTH} tokens."
+        )
+
     input_ids = input_ids.to(model.device)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(
+        tokenizer,
+        timeout=10.0,
+        skip_prompt=True,
+        skip_special_tokens=True,
+    )
+
     generate_kwargs = dict(
-        {"input_ids": input_ids},
+        input_ids=input_ids,
         streamer=streamer,
-        max_new_tokens=max_new_tokens,
+        max_new_tokens=min(max_new_tokens, MAX_MAX_NEW_TOKENS),
         do_sample=True,
+        temperature=temperature,
         top_p=top_p,
         top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
         repetition_penalty=repetition_penalty,
     )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
 
-    outputs = []
+    thread = Thread(target=model.generate, kwargs=generate_kwargs)
+    thread.start()
+
+    outputs: list[str] = []
     for text in streamer:
         outputs.append(text)
         yield "".join(outputs)
 
+
 chat_interface = gr.ChatInterface(
-    theme='ehristoforu/RE_Theme',
     fn=generate,
     additional_inputs=[
         gr.Textbox(label="System prompt", lines=6),
-        gr.Radio(["Fast-Model", "Quality-Model"], label="Model", value="Quality-Model"),
+        gr.Radio(
+            ["Fast-Model", "Quality-Model"],
+            label="Model",
+            value="Quality-Model",
+        ),
         gr.Slider(
             label="Max new tokens",
             minimum=1,
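The refactored body keeps the standard transformers streaming recipe: build inputs with `apply_chat_template`, run `model.generate` on a worker thread, and iterate the `TextIteratorStreamer` to yield partial text as it arrives. A minimal, self-contained sketch of that recipe, assuming the checkpoint ships a chat template:

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "Artples/L-MChat-Small"  # illustrative; any chat-templated model works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

conversation = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "What is L-MChat?"},
]
# add_generation_prompt=True appends the assistant header so the model
# answers instead of continuing the user turn.
input_ids = tokenizer.apply_chat_template(
    conversation, return_tensors="pt", add_generation_prompt=True
)

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until finished, so it runs on a worker thread while
# the main thread consumes decoded chunks from the streamer.
thread = Thread(
    target=model.generate,
    kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=64),
)
thread.start()

for chunk in streamer:
    print(chunk, end="", flush=True)
```

The `timeout=10.0` in the diff guards the consuming loop against a stalled generation; it is omitted here for brevity.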
@@ -126,7 +161,7 @@ chat_interface = gr.ChatInterface(
     ],
 )
 
-with gr.Blocks(css="style.css") as demo:
+with gr.Blocks(css="styles.css") as demo:
     gr.Markdown(DESCRIPTION)
     chat_interface.render()
 
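The final hunk only renames the stylesheet from `style.css` to `styles.css`. For orientation, this is the usual way a pre-built `gr.ChatInterface` is embedded in a `gr.Blocks` page; the placeholder `echo` function and the queue-and-launch tail are assumptions, since the visible diff ends at `chat_interface.render()`:

```python
import gradio as gr


def echo(message, history):
    # Placeholder for the real generate() streamer in app.py.
    return message


chat_interface = gr.ChatInterface(fn=echo)

with gr.Blocks(css="styles.css") as demo:
    gr.Markdown("# L-MChat")  # stands in for the DESCRIPTION string
    chat_interface.render()   # embed the pre-built ChatInterface

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
```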