Ascol57 committed on
Commit
98879d8
·
verified ·
1 Parent(s): de10e0f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
def generate(
    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    """Stream a model completion for *prompt*, yielding the growing response text.

    Parameters
    ----------
    prompt:
        The user's new message.
    history:
        Prior chat turns, forwarded to ``format_prompt`` (shape assumed to be
        whatever ``format_prompt`` expects — not visible in this file).
    temperature, max_new_tokens, top_p, repetition_penalty:
        Sampling controls forwarded to the inference client.

    Yields
    ------
    str
        The accumulated response text after each streamed token.
    """
    # Clamp instead of the original if-branch: a temperature at or near 0
    # is degenerate for sampling backends, so enforce a small floor.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,  # fixed seed: identical requests reproduce identical samples
    )

    # NOTE(review): `format_prompt` and `client` are not defined anywhere in
    # this file — they must come from code outside this view, otherwise this
    # raises NameError at call time. Confirm they exist in the full module.
    formatted_prompt = format_prompt(prompt, history)

    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )

    output = ""
    for response in stream:
        output += response.token.text
        yield output
    # No trailing `return output`: the original returned a value from this
    # generator, but that value (StopIteration.value) is never consumed by
    # gr.ChatInterface, so it was dead code.
29
+
30
+
31
# Sampling controls surfaced in the chat UI. gr.ChatInterface appends these
# values positionally after (prompt, history) when calling `generate`, so the
# order here must match the function's extra parameters.
# NOTE(review): the Top-p slider defaults to 0.90 while `generate` declares
# top_p=0.95 — presumably the slider wins; confirm this is intentional.
additional_inputs = [
    gr.Slider(
        minimum=0.0,
        maximum=1.0,
        value=0.9,
        step=0.05,
        label="Temperature",
        info="Higher values produce more diverse outputs",
        interactive=True,
    ),
    gr.Slider(
        minimum=0,
        maximum=8192,
        value=256,
        step=64,
        label="Max new tokens",
        info="The maximum numbers of new tokens",
        interactive=True,
    ),
    gr.Slider(
        minimum=0.0,
        maximum=1,
        value=0.90,
        step=0.05,
        label="Top-p (nucleus sampling)",
        info="Higher values sample more low-probability tokens",
        interactive=True,
    ),
    gr.Slider(
        minimum=1.0,
        maximum=2.0,
        value=1.2,
        step=0.05,
        label="Repetition penalty",
        info="Penalize repeated tokens",
        interactive=True,
    ),
]
69
+
70
# Assemble the UI: a chat interface backed by the streaming `generate`
# function, with the sampling sliders attached as extra inputs.
with gr.Blocks() as demo:
    gr.ChatInterface(generate, additional_inputs=additional_inputs)

# Enable request queuing (needed for streaming generators) and serve.
demo.queue().launch(debug=True)