Florian Valade committed on
Commit
7e9cb9e
·
1 Parent(s): 97675ea

Update demo to use Gradio

Browse files
Files changed (2) hide show
  1. app.py +134 -165
  2. requirements.txt +2 -2
app.py CHANGED
@@ -1,176 +1,145 @@
1
- # Save this as app.py and run with `streamlit run app.py`
2
- import time
3
- import streamlit as st
4
  import torch
5
  import pandas as pd
6
  import plotly.graph_objects as go
7
- import numpy as np
8
-
9
  from plotly.subplots import make_subplots
10
  from transformers import AutoModelForCausalLM, AutoTokenizer
11
- from typer import clear
12
- from annotated_text import annotated_text
13
-
14
- st.title("Multi-Head LLM Demo")
15
- st.markdown("""This is a demo of a multi-head language model with early exit capabilities.
16
- The model is based on the Phi-2 architecture and model is available here : https://huggingface.co/valcore/Branchy-Phi-2.
17
- \nThe model has four heads, each of which can be exited early based on a threshold. The graph show the depth of early exit for each token (the deeper being the faster) and the time taken to generate each token.
18
- Early exited tokens are annotated with the depth of early exit (with a float smaller than 1, 1 being the deepest)
19
- """)
20
-
21
- def annotated_to_normal(text):
22
- result = ""
23
- for elem in text:
24
- if isinstance(elem, tuple):
25
- result += elem[0]
26
- else:
27
- result += elem
28
- return result
29
 
30
- def generate_next_token(device="cpu"):
31
- print(f"Generating next token from {st.session_state.messages}")
32
- inputs = ""
33
- for message in st.session_state.messages:
34
- inputs += message["role"] + ": " + annotated_to_normal(message["content"]) + "\n"
35
- inputs += "Assistant:"
36
- print(f"Inputs: {inputs}")
37
- inputs = st.session_state.tokenizer.encode(inputs, return_tensors="pt").to(device)
38
- for i in range(50):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  start = time.time()
40
- outputs = st.session_state.model(inputs)
41
  stop = time.time()
42
- next_token_logits = outputs.logits[:, -1, :].squeeze()
43
- next_token_probs = torch.softmax(next_token_logits, dim=-1)
44
- next_token_id = torch.argmax(next_token_probs, dim=-1)
45
- if next_token_id == 50256:
 
46
  break
47
- print(inputs.shape, next_token_id.shape)
48
- inputs = torch.cat([inputs, next_token_id.unsqueeze(0).unsqueeze(-1)], dim=-1)
49
- next_token = st.session_state.tokenizer.decode(next_token_id, return_tensors="pt")
50
- time_taken = stop - start
51
- branch_locations = st.session_state.model.config.branch_locations
52
- print(outputs.head_indices)
53
- if outputs.head_indices in branch_locations:
54
- print(sorted(branch_locations, reverse=True))
55
- early_exit = (branch_locations.index(outputs.head_indices) + 1) / len(branch_locations)
56
- else:
57
- early_exit = 1.25
58
- # Add data to dataframe
59
- new_row = pd.DataFrame({"Time taken (in ms)": [time_taken], "Early exit depth": [early_exit]})
60
- st.session_state.data = pd.concat([st.session_state.data, new_row], ignore_index=True)
61
- yield next_token, early_exit
62
-
63
- @st.cache_resource
64
- def load_model(model_str, tokenizer_str, device="cpu"):
65
- model = AutoModelForCausalLM.from_pretrained(model_str, trust_remote_code=True).to(device)
66
- model.eval()
67
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_str)
68
- return model, tokenizer
69
-
70
- model_str = "valcore/Branchy-Phi-2"
71
- tokenizer_str = "microsoft/Phi-2"
72
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
73
-
74
- if "model" not in st.session_state or "tokenizer" not in st.session_state:
75
- print(f"Loading model on {device}")
76
- st.session_state.model, st.session_state.tokenizer = load_model(model_str, tokenizer_str, device)
77
-
78
- # Initialize chat history and dataframe
79
- if "messages" not in st.session_state:
80
- st.session_state.messages = []
81
- st.session_state.data = pd.DataFrame(columns=["Time taken (in ms)", "Early exit depth"])
82
-
83
- col1, col2 = st.columns([1, 4])
84
-
85
- with col1:
86
- early_exit = st.checkbox("Early exit", value=False)
87
- if early_exit:
88
- st.session_state.model.head_thresholds = [2.506962537765503, 2.656052589416504, 1.924393653869629, 1.4434680938720703]
89
- else:
90
- st.session_state.model.head_thresholds = [10., 10., 10., 10.]
91
- clear_session = st.button("Clear session")
92
- if clear_session:
93
- print("Clearing session")
94
- st.session_state.messages = []
95
- st.session_state.data = pd.DataFrame(columns=["Time taken (in ms)", "Early exit depth"])
96
-
97
- with col2:
98
- # Display chat messages from history on app rerun
99
- for message in st.session_state.messages:
100
- with st.chat_message(message["role"]):
101
- annotated_text(message["content"])
102
-
103
- prompt = st.chat_input("What is up?")
104
- # React to user input
105
- if prompt:
106
- # Display user message in chat message container
107
- with st.chat_message("User"):
108
- st.markdown(prompt)
109
- # Add user message to chat history
110
- st.session_state.messages.append({"role": "User", "content": prompt})
111
-
112
- # Display assistant response in chat message container
113
- with st.chat_message("Assistant"):
114
- response = []
115
- with st.spinner('Running inference...'):
116
- for next_token, early_exit in generate_next_token(device):
117
- if early_exit > 0.0 and early_exit != 1.25:
118
- response.append(tuple((next_token, str(early_exit))))
119
- else:
120
- response.append(next_token)
121
- print(response)
122
- annotated_text(response)
123
-
124
- # Add assistant response to chat history
125
- st.session_state.messages.append({"role": "Assistant", "content": response})
126
-
127
- # Assuming st.session_state.data is a pandas DataFrame
128
- df = st.session_state.data
129
-
130
- # Calculate the max time taken and add a 10% margin
131
- max_time = df["Time taken (in ms)"].max()
132
- time_axis_max = max_time * 1.1 # 10% margin
133
-
134
-
135
- # Create figure with secondary y-axis
136
- fig = make_subplots(specs=[[{"secondary_y": True}]])
137
-
138
- # Add traces
139
- fig.add_trace(
140
- go.Scatter(x=df.index, y=df["Time taken (in ms)"], name="Time taken (in ms)"),
141
- secondary_y=False,
142
- )
143
-
144
- fig.add_trace(
145
- go.Scatter(x=df.index, y=df["Early exit depth"], name="Early exit depth"),
146
- secondary_y=True,
147
- )
148
-
149
- # Set x-axis title
150
- fig.update_xaxes(title_text="Index")
151
-
152
- # Set y-axes titles
153
- fig.update_yaxes(
154
- title_text="Time taken (in ms)",
155
- secondary_y=False,
156
- range=[0, time_axis_max],
157
- tickmode='linear',
158
- dtick=np.ceil(time_axis_max / 5 / 10) * 10 # Round to nearest 10
159
- )
160
- fig.update_yaxes(
161
- title_text="Early exit depth",
162
- secondary_y=True,
163
- range=[0, 1.25],
164
- tickmode='linear',
165
- dtick=0.25
166
- )
167
 
168
- fig.update_layout(
169
- title_text="Time Taken vs Early Exit Depth",
170
- legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
171
- )
172
- # Use Streamlit to display the Plotly chart
173
- st.plotly_chart(fig)
 
 
 
 
 
 
 
 
174
 
175
- #st.line_chart(st.session_state.data, x=None, y=["Time taken (in ms)", "Early exit depth"])
176
- print(st.session_state.messages)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
 
 
2
  import torch
3
  import pandas as pd
4
  import plotly.graph_objects as go
 
 
5
  from plotly.subplots import make_subplots
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
+ import time
8
+ import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
# Load the model and tokenizer
# NOTE(review): trust_remote_code=True executes the model repo's custom code
# (required for the Branchy multi-head architecture) — only safe for trusted repos.
model_str = "valcore/Branchy-Phi-2"
tokenizer_str = "microsoft/Phi-2"
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(model_str, trust_remote_code=True).to(device)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_str)

# Initialize dataframe for storing token generation data
# (one row per generated token; rebuilt at the start of every request).
data = pd.DataFrame(columns=["Time taken (in ms)", "Early exit depth", "Token"])

# Define thresholds for different epsilon values.
# Each entry maps an epsilon (UI slider value) to per-head confidence
# thresholds; lower thresholds let heads exit earlier (faster, less accurate).
epsilon_thresholds = {
    0.4: [1.0307843685150146, 0.8693032264709473, 0.6637287139892578, 0.3111608028411865],
    0.5: [1.505380630493164, 1.5712471008300781, 1.1971790790557861, 0.6908178329467773],
    0.6: [2.0270779132843018, 1.8969502449035645, 1.4789371490478516, 0.9875392913818359],
    0.7: [2.506962537765503, 2.656052589416504, 1.924393653869629, 1.4434680938720703],
    0.8: [3.3786778450012207, 2.568857192993164, 2.5665550231933594, 2.006620407104492],
    0.9: [3.187114715576172, 3.442272663116455, 2.636230945587158, 2.460529088973999],
    1.0: [10.0, 10.0, 10.0, 10.0]  # Effectively disable early exits
}

# Global variable to control generation; set True by the Stop button.
stop_generation = False
34
+
35
def create_plot():
    """Build a dual-axis Plotly figure from the module-level ``data`` frame.

    Left axis: per-token generation latency in ms; right axis: normalized
    early-exit depth. Each point's hover tooltip shows the decoded token.
    """
    figure = make_subplots(specs=[[{"secondary_y": True}]])

    latency_trace = go.Scatter(
        x=data.index,
        y=data["Time taken (in ms)"],
        name="Time taken (ms)",
        text=data["Token"],
        hovertemplate="<b>Token:</b> %{text}<br><b>Time:</b> %{y:.2f} ms<extra></extra>",
    )
    depth_trace = go.Scatter(
        x=data.index,
        y=data["Early exit depth"],
        name="Early exit depth",
        text=data["Token"],
        hovertemplate="<b>Token:</b> %{text}<br><b>Depth:</b> %{y:.2f}<extra></extra>",
    )

    figure.add_trace(latency_trace, secondary_y=False)
    figure.add_trace(depth_trace, secondary_y=True)

    figure.update_layout(
        title_text="Token Generation Metrics",
        xaxis_title="Token Index",
        yaxis_title="Time (ms)",
        yaxis2_title="Exit Depth",
        hovermode="closest",
    )

    # Depth is in (0, 1]; a little headroom keeps markers off the frame edge.
    figure.update_yaxes(range=[0, 1.1], secondary_y=True)

    return figure
71
+
72
def truncate_context(input_ids, max_length=2048):
    """Clamp a (1, seq_len) token-id tensor to the trailing ``max_length`` ids.

    Returns the tensor unchanged when it already fits the context window;
    otherwise keeps only the most recent ``max_length`` tokens.
    """
    seq_len = len(input_ids[0])
    if seq_len <= max_length:
        return input_ids
    return input_ids[:, -max_length:]
76
+
77
def generate_response(message, chat_history, epsilon):
    """Greedily generate a reply to ``message``, streaming results to the UI.

    A generator for Gradio: each yield delivers (chatbot history, state
    history, plot update) so the chat transcript and metrics graph refresh
    as tokens are produced. Generation stops on EOS or when the global
    ``stop_generation`` flag is raised by the Stop button.
    """
    global data, stop_generation
    # Reset per-request metrics and the stop flag before generating.
    data = pd.DataFrame(columns=["Time taken (in ms)", "Early exit depth", "Token"])
    stop_generation = False

    # Set model thresholds based on epsilon (per-head early-exit thresholds).
    model.head_thresholds = torch.tensor(epsilon_thresholds[epsilon])

    full_response = ""
    chat_history = chat_history or []
    inputs = tokenizer.encode(message, return_tensors="pt").to(device)

    # NOTE(review): no upper bound on generated tokens — the loop ends only
    # on EOS or the Stop button. Confirm whether a max-token cap is wanted.
    while not stop_generation:
        # Keep the prompt within the model's context window.
        inputs = truncate_context(inputs)
        start = time.time()
        outputs = model(inputs)
        stop = time.time()

        # Greedy decoding: arg-max over the logits at the final position.
        next_token_logits = outputs.logits[:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1)

        if next_token_id.item() == tokenizer.eos_token_id:
            break

        # Append the new token and decode it for display.
        inputs = torch.cat([inputs, next_token_id.unsqueeze(0)], dim=-1)
        next_token = tokenizer.decode(next_token_id)
        full_response += next_token

        time_taken = (stop - start) * 1000  # Convert to milliseconds
        branch_locations = model.config.branch_locations
        # Normalized exit depth in (0, 1]; 1.0 means no early-exit head fired.
        # NOTE(review): assumes outputs.head_indices is a scalar comparable to
        # entries of branch_locations — confirm against the Branchy model code.
        early_exit = (branch_locations.index(outputs.head_indices) + 1) / len(branch_locations) if outputs.head_indices in branch_locations else 1.0

        # Record this token's metrics for the live plot.
        new_row = pd.DataFrame({
            "Time taken (in ms)": [time_taken],
            "Early exit depth": [early_exit],
            "Token": [next_token]
        })
        data = pd.concat([data, new_row], ignore_index=True)

        # NOTE(review): indentation reconstructed from the diff — yielding
        # inside the loop streams one UI update per token; confirm intent.
        new_history = chat_history + [(message, full_response)]
        yield new_history, new_history, gr.update(value=create_plot())
118
+
119
def stop_gen():
    """Request that the running generation loop stop after the current token.

    Returns a Gradio update disabling the Stop button while the request is
    honored. NOTE(review): nothing visible re-enables the button afterwards —
    confirm whether it should be re-activated when a new generation starts.
    """
    global stop_generation
    stop_generation = True
    return gr.update(interactive=False)
123
+
124
# --- Gradio UI: chat transcript, epsilon slider, send/stop controls, plot ---
with gr.Blocks() as demo:
    gr.Markdown("# Multi-Head LLM Demo with Early Exit Capabilities 🤗")
    gr.Markdown("""This is a demo of a multi-head language model with early exit capabilities.
The model is based on the Phi-2 architecture and is available here: https://huggingface.co/valcore/Branchy-Phi-2.
The model has four heads, each of which can be exited early based on a threshold. The graph shows the depth of early exit for each token and the time taken to generate each token.
Use the slider to adjust the early exit threshold. Lower values allow for more early exits, potentially speeding up generation at the cost of accuracy.
""")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Message")
    # Epsilon selects a precomputed per-head threshold set; 1.0 disables early exit.
    epsilon = gr.Slider(minimum=0.4, maximum=1.0, value=0.7, step=0.1, label="Epsilon")

    with gr.Row():
        send = gr.Button("Send")
        stop = gr.Button("Stop Generation")

    graph = gr.Plot()

    # Both the Send button and Enter in the textbox start generation;
    # generate_response is a generator, so outputs stream to the chatbot
    # (listed twice: display component and history state) and the plot.
    send.click(generate_response, inputs=[msg, chatbot, epsilon], outputs=[chatbot, chatbot, graph])
    msg.submit(generate_response, inputs=[msg, chatbot, epsilon], outputs=[chatbot, chatbot, graph])
    stop.click(stop_gen, outputs=[stop])

# queue() is required for generator (streaming) event handlers.
demo.queue().launch()
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
- streamlit==1.31.0
2
  torch==2.0.1
3
  pandas==2.0.3
4
  transformers==4.36.0
5
- st-annotated-text
 
1
+ gradio==4.32.2
2
  torch==2.0.1
3
  pandas==2.0.3
4
  transformers==4.36.0
5
+ plotly==5.22.0