Add streaming, disable translation
Also upgrade transformers, add sentencepiece
- app.py +110 -81
- requirements.txt +8 -5
app.py
CHANGED
@@ -10,6 +10,7 @@ from transformers import (
     AutoModelForCausalLM,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
+    TextIteratorStreamer,
     pipeline,
     set_seed,
 )
@@ -41,6 +42,20 @@ def load_model(model_name, task):
     return tokenizer, model
 
 
+class StreamlitTextIteratorStreamer(TextIteratorStreamer):
+    def __init__(
+        self, output_placeholder, tokenizer, skip_prompt=False, **decode_kwargs
+    ):
+        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+        self.output_placeholder = output_placeholder
+        self.output_text = ""
+
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        self.output_text += text
+        self.output_placeholder.markdown(self.output_text, unsafe_allow_html=True)
+        super().on_finalized_text(text, stream_end)
+
+
 class Generator:
     def __init__(self, model_name, task, desc):
         self.model_name = model_name
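The subclass above streams by overriding on_finalized_text, the callback that transformers' streamer classes invoke with each newly decoded chunk of text. A minimal standalone sketch of the same callback pattern, using the simpler TextStreamer base class; BufferStreamer and the tiny model id are illustrative, not part of this app:

from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

class BufferStreamer(TextStreamer):
    """Collects finalized chunks instead of printing them (TextStreamer's default)."""

    def __init__(self, tokenizer, **decode_kwargs):
        super().__init__(tokenizer, **decode_kwargs)
        self.chunks = []

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # Called by generate() for each decoded chunk; a UI would render here.
        self.chunks.append(text)

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")  # illustrative model
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
inputs = tokenizer("Hallo", return_tensors="pt")
streamer = BufferStreamer(tokenizer)
model.generate(**inputs, streamer=streamer, max_new_tokens=8)
print("".join(streamer.chunks))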
@@ -52,18 +67,38 @@ class Generator:
         self.load()
 
     def load(self):
-        if not self.
+        if not self.model:
             print(f"Loading model {self.model_name}")
             self.tokenizer, self.model = load_model(self.model_name, self.task)
-            self.pipeline = pipeline(
-                task=self.task,
-                model=self.model,
-                tokenizer=self.tokenizer,
-                device=device,
-            )
 
-    def
-
+    def generate(self, text: str, streamer=None, **generate_kwargs) -> (str, dict):
+        batch_encoded = self.tokenizer(
+            text,
+            max_length=generate_kwargs["max_length"],
+            padding=False,
+            truncation=False,
+            return_tensors="pt",
+        )
+        if device != -1:
+            batch_encoded.to(f"cuda:{device}")
+        logits = self.model.generate(
+            batch_encoded["input_ids"],
+            attention_mask=batch_encoded["attention_mask"],
+            streamer=streamer,
+            **generate_kwargs,
+        )
+        decoded_preds = self.tokenizer.batch_decode(
+            logits.cpu().numpy(), skip_special_tokens=False
+        )
+
+        def replace_tokens(pred):
+            pred = pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
+            if hasattr(self.tokenizer, "newline_token"):
+                pred = pred.replace(self.tokenizer.newline_token, "\n")
+            return pred
+
+        decoded_preds = list(map(replace_tokens, decoded_preds))
+        return decoded_preds[0], generate_kwargs
 
 
 class GeneratorFactory:
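The new generate() replaces the earlier pipeline() call so a streamer can be passed straight through to model.generate(). A hypothetical call, assuming the app's module context (Generator, device, load_model); the model id and sampling parameters are illustrative. Note that generate_kwargs must contain "max_length", because the tokenizer call above reads it:

# Hypothetical usage sketch; assumes the app's Generator class is in scope.
gen = Generator("gpt2", "text-generation", "GPT2 (illustrative)")
text, used_params = gen.generate(
    "Het was een koude winterdag",
    streamer=None,  # pass a StreamlitTextIteratorStreamer instance to stream tokens
    max_length=50,  # required: generate() reads generate_kwargs["max_length"]
    do_sample=True,
    top_p=0.95,
)
print(text)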
@@ -82,11 +117,11 @@ class GeneratorFactory:
             "desc": "GPT2 Medium Dutch (book finetune)",
             "task": "text-generation",
         },
-        {
-            "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
-            "desc": "Dutch<->English T5 small 24 layers",
-            "task": TRANSLATION_NL_TO_EN,
-        },
+        # {
+        #     "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
+        #     "desc": "Dutch<->English T5 small 24 layers",
+        #     "task": TRANSLATION_NL_TO_EN,
+        # },
     ]
     for g in GENERATOR_LIST:
         with st.spinner(text=f"Loading the model {g['desc']} ..."):
@@ -148,12 +183,13 @@ def main():
     repetition_penalty = st.sidebar.number_input(
         "Repetition penalty", min_value=0.0, max_value=5.0, value=1.2, step=0.1
     )
-    num_return_sequences = st.sidebar.number_input(
-        "Num return sequences", min_value=1, max_value=5, value=1
-    )
+    num_return_sequences = 1
+    # st.sidebar.number_input(
+    #     "Num return sequences", min_value=1, max_value=5, value=1
+    # )
     seed_placeholder = st.sidebar.empty()
     if "seed" not in st.session_state:
-        print(f"Session state
+        print(f"Session state does not contain seed")
         st.session_state["seed"] = 4162549114
     print(f"Seed is set to: {st.session_state['seed']}")
 
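Streamlit reruns the whole script on every widget interaction; keeping the seed in st.session_state is what lets it survive those reruns within a browser session. A standalone sketch of the pattern (not app code):

import streamlit as st

if "seed" not in st.session_state:
    # First run of this session: initialize once; later reruns keep the value.
    st.session_state["seed"] = 4162549114
st.write(f"Seed is {st.session_state['seed']}")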
@@ -218,69 +254,62 @@ and the [Huggingface text generation interface doc](https://huggingface.co/trans
     )
 
     if st.button("Run"):
-        params["seed"] = seed
-        params["prompt"] = st.session_state.text
-        params["model"] = generator.model_name
-        params_text = json.dumps(params)
-        print(params_text)
-        st.json(params_text)
+        memory = psutil.virtual_memory()
+        st.subheader("Result")
+        container = st.container()
+        output_placeholder = container.empty()
+        streaming_enabled = True  # sampling_mode != "Beam Search" or num_beams == 1
+        generator = generators.get_generator(desc=model_desc)
+        streamer = (
+            StreamlitTextIteratorStreamer(output_placeholder, generator.tokenizer)
+            if streaming_enabled
+            else None
+        )
+        set_seed(seed)
+        time_start = time.time()
+        result = generator.generate(
+            text=st.session_state.text, streamer=streamer, **params
+        )
+        time_end = time.time()
+        time_diff = time_end - time_start
+
+        # for text in result:
+        #     st.write(text.get("generated_text").replace("\n", "  \n"))
+        #     st.text("*Translation*")
+        #     translate_params = {
+        #         "num_return_sequences": 1,
+        #         "num_beams": 4,
+        #         "early_stopping": True,
+        #         "length_penalty": 1.1,
+        #         "max_length": 200,
+        #     }
+        #     text_lines = [
+        #         "translate Dutch to English: " + t
+        #         for t in text.get("generated_text").splitlines()
+        #     ]
+        #     translated_lines = [
+        #         t["translation_text"]
+        #         for t in generators.get_generator(
+        #             task=TRANSLATION_NL_TO_EN
+        #         ).get_text(text_lines, **translate_params)
+        #     ]
+        #     translation = "  \n".join(translated_lines)
+        #     st.write(translation)
+        #     st.write("---")
+        #
+        info = f"""
+        ---
+        *Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*
+        *Text generated using seed {seed} in {time_diff:.5} seconds*
+        """
+        st.write(info)
+
+        params["seed"] = seed
+        params["prompt"] = st.session_state.text
+        params["model"] = generator.model_name
+        params_text = json.dumps(params)
+        # print(params_text)
+        st.json(params_text)
 
 
 if __name__ == "__main__":
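Design note: TextIteratorStreamer is normally drained from the main thread while generate() runs in a worker; this commit instead pushes each chunk to the Streamlit placeholder directly from inside the blocking generate() call, so no second thread is needed. For contrast, a rough sketch of the thread-based pattern from the transformers documentation; the function and parameter names here are illustrative:

from threading import Thread
from transformers import TextIteratorStreamer

def stream_to_placeholder(model, tokenizer, prompt, placeholder, **generate_kwargs):
    streamer = TextIteratorStreamer(tokenizer)
    inputs = tokenizer(prompt, return_tensors="pt")
    # Run generation in a worker so the main thread can consume the queue.
    Thread(
        target=model.generate,
        kwargs=dict(inputs, streamer=streamer, **generate_kwargs),
    ).start()
    text = ""
    for chunk in streamer:  # blocks until the worker emits the next chunk
        text += chunk
        placeholder.markdown(text)  # incremental UI update
    return text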
requirements.txt
CHANGED
@@ -1,7 +1,10 @@
+#-f https://download.pytorch.org/whl/torch_stable.html
+-f https://download.pytorch.org/whl/cu116
+-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+protobuf<3.20
+streamlit>=1.4.0,<=1.10.0
+torch
+git+https://github.com/huggingface/transformers.git@1905384fd576acf4b645a8216907f980b4788d9b
 mtranslate
 psutil
+sentencepiece