Spaces: Running on Zero
Florian Valade committed on
Commit · 97675ea
Parent(s): ad27be1

Update graph to be more understandable

app.py CHANGED
@@ -3,7 +3,10 @@ import time
 import streamlit as st
 import torch
 import pandas as pd
+import plotly.graph_objects as go
+import numpy as np
 
+from plotly.subplots import make_subplots
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from typer import clear
 from annotated_text import annotated_text
@@ -24,14 +27,14 @@ def annotated_to_normal(text):
         result += elem
     return result
 
-def generate_next_token():
+def generate_next_token(device="cpu"):
     print(f"Generating next token from {st.session_state.messages}")
     inputs = ""
     for message in st.session_state.messages:
         inputs += message["role"] + ": " + annotated_to_normal(message["content"]) + "\n"
     inputs += "Assistant:"
     print(f"Inputs: {inputs}")
-    inputs = st.session_state.tokenizer.encode(inputs, return_tensors="pt")
+    inputs = st.session_state.tokenizer.encode(inputs, return_tensors="pt").to(device)
     for i in range(50):
         start = time.time()
         outputs = st.session_state.model(inputs)
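Note: this hunk leaves the prompt construction unchanged. The chat history (with annotated tuples flattened by annotated_to_normal) is concatenated into a plain "Role: content" transcript before tokenization. A minimal sketch of the resulting string, using synthetic messages that are not from the commit:

messages = [
    {"role": "User", "content": "Hello!"},
    {"role": "Assistant", "content": "Hi, how can I help?"},
]
# Same concatenation as the loop in the hunk above.
prompt = "".join(m["role"] + ": " + m["content"] + "\n" for m in messages) + "Assistant:"
print(prompt)
# User: Hello!
# Assistant: Hi, how can I help?
# Assistant: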
@@ -51,25 +54,26 @@ def generate_next_token():
             print(sorted(branch_locations, reverse=True))
             early_exit = (branch_locations.index(outputs.head_indices) + 1) / len(branch_locations)
         else:
-            early_exit = 0.0
+            early_exit = 1.25
         # Add data to dataframe
         new_row = pd.DataFrame({"Time taken (in ms)": [time_taken], "Early exit depth": [early_exit]})
         st.session_state.data = pd.concat([st.session_state.data, new_row], ignore_index=True)
         yield next_token, early_exit
 
 @st.cache_resource
-def load_model(model_str, tokenizer_str):
-    model = AutoModelForCausalLM.from_pretrained(model_str, trust_remote_code=True)
+def load_model(model_str, tokenizer_str, device="cpu"):
+    model = AutoModelForCausalLM.from_pretrained(model_str, trust_remote_code=True).to(device)
     model.eval()
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_str)
     return model, tokenizer
 
 model_str = "valcore/Branchy-Phi-2"
 tokenizer_str = "microsoft/Phi-2"
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
 if "model" not in st.session_state or "tokenizer" not in st.session_state:
-    print("Loading model")
-    st.session_state.model, st.session_state.tokenizer = load_model(model_str, tokenizer_str)
+    print(f"Loading model on {device}")
+    st.session_state.model, st.session_state.tokenizer = load_model(model_str, tokenizer_str, device)
 
 # Initialize chat history and dataframe
 if "messages" not in st.session_state:
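Note on the 1.25 value: early_exit is the normalized depth of the exit branch that fired, so it lies in (0, 1]; 1.25 appears to act as an out-of-band sentinel for "no early exit, the token ran through the full model", sitting one tick above the [0, 1.25] depth axis configured further down. A hypothetical standalone sketch of that convention (the names here are illustrative, not from the commit):

NO_EXIT = 1.25  # sentinel: token went through the full model

def exit_depth(branch_locations, head_index):
    # Normalized depth in (0, 1] of the branch that produced the token,
    # mirroring the (index + 1) / len(...) expression in the diff.
    if head_index in branch_locations:
        return (branch_locations.index(head_index) + 1) / len(branch_locations)
    return NO_EXIT

assert exit_depth([4, 8, 12, 16], 8) == 0.5     # exited at the second of four branches
assert exit_depth([4, 8, 12, 16], None) == NO_EXIT  # no early exit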
@@ -109,8 +113,8 @@ with col2:
     with st.chat_message("Assistant"):
         response = []
         with st.spinner('Running inference...'):
-            for next_token, early_exit in generate_next_token():
-                if early_exit > 0.0:
+            for next_token, early_exit in generate_next_token(device):
+                if early_exit > 0.0 and early_exit != 1.25:
                     response.append(tuple((next_token, str(early_exit))))
                 else:
                     response.append(next_token)
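The response list mixes plain tokens with (token, label) tuples, so tokens that exited early can render as highlighted chips labeled with their exit depth via the st-annotated-text package imported at the top of the file. A minimal usage sketch; the rendering call is assumed, as it is not shown in this hunk:

from annotated_text import annotated_text

response = ["The ", ("cat ", "0.5"), ("sat ", "0.25"), "down"]
annotated_text(*response)  # plain strings render as text, tuples as labeled chips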
@@ -119,5 +123,54 @@ with col2:
 
     # Add assistant response to chat history
     st.session_state.messages.append({"role": "Assistant", "content": response})
-    st.line_chart(st.session_state.data, x=None, y=["Time taken (in ms)", "Early exit depth"])
+
+    # Assuming st.session_state.data is a pandas DataFrame
+    df = st.session_state.data
+
+    # Calculate the max time taken and add a 10% margin
+    max_time = df["Time taken (in ms)"].max()
+    time_axis_max = max_time * 1.1  # 10% margin
+
+
+    # Create figure with secondary y-axis
+    fig = make_subplots(specs=[[{"secondary_y": True}]])
+
+    # Add traces
+    fig.add_trace(
+        go.Scatter(x=df.index, y=df["Time taken (in ms)"], name="Time taken (in ms)"),
+        secondary_y=False,
+    )
+
+    fig.add_trace(
+        go.Scatter(x=df.index, y=df["Early exit depth"], name="Early exit depth"),
+        secondary_y=True,
+    )
+
+    # Set x-axis title
+    fig.update_xaxes(title_text="Index")
+
+    # Set y-axes titles
+    fig.update_yaxes(
+        title_text="Time taken (in ms)",
+        secondary_y=False,
+        range=[0, time_axis_max],
+        tickmode='linear',
+        dtick=np.ceil(time_axis_max / 5 / 10) * 10  # Round to nearest 10
+    )
+    fig.update_yaxes(
+        title_text="Early exit depth",
+        secondary_y=True,
+        range=[0, 1.25],
+        tickmode='linear',
+        dtick=0.25
+    )
+
+    fig.update_layout(
+        title_text="Time Taken vs Early Exit Depth",
+        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
+    )
+    # Use Streamlit to display the Plotly chart
+    st.plotly_chart(fig)
+
+    #st.line_chart(st.session_state.data, x=None, y=["Time taken (in ms)", "Early exit depth"])
 print(st.session_state.messages)
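The new chart uses Plotly's secondary-y-axis pattern in place of the old st.line_chart, which forced both series onto one scale. As a worked example of the tick computation: with a max time of 180 ms, time_axis_max = 198, and np.ceil(198 / 5 / 10) * 10 = np.ceil(3.96) * 10 = 40, giving roughly five ticks rounded to a multiple of 10. A self-contained sketch of the same figure with synthetic data (the app plots st.session_state.data instead):

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

df = pd.DataFrame({
    "Time taken (in ms)": [120.0, 95.0, 180.0, 140.0],
    "Early exit depth": [0.5, 0.25, 1.25, 0.75],
})

time_axis_max = df["Time taken (in ms)"].max() * 1.1  # 180 -> 198
dtick = np.ceil(time_axis_max / 5 / 10) * 10          # 198 -> 40

# One subplot cell with a second y-axis attached to it
fig = make_subplots(specs=[[{"secondary_y": True}]])
fig.add_trace(go.Scatter(x=df.index, y=df["Time taken (in ms)"],
                         name="Time taken (in ms)"), secondary_y=False)
fig.add_trace(go.Scatter(x=df.index, y=df["Early exit depth"],
                         name="Early exit depth"), secondary_y=True)
fig.update_yaxes(title_text="Time taken (in ms)", range=[0, time_axis_max],
                 tickmode="linear", dtick=dtick, secondary_y=False)
fig.update_yaxes(title_text="Early exit depth", range=[0, 1.25],
                 tickmode="linear", dtick=0.25, secondary_y=True)
fig.show()  # in the app: st.plotly_chart(fig)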