Spaces:
Runtime error
Runtime error
Change SDK
Browse files
README.md
CHANGED
|
@@ -3,8 +3,8 @@ title: Bloom Chat
|
|
| 3 |
emoji: ⚡
|
| 4 |
colorFrom: purple
|
| 5 |
colorTo: green
|
| 6 |
-
sdk:
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: openrail
|
|
|
|
| 3 |
emoji: ⚡
|
| 4 |
colorFrom: purple
|
| 5 |
colorTo: green
|
| 6 |
+
sdk: streamlit
|
| 7 |
+
sdk_version: 1.10.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: openrail
|
app.py
CHANGED
|
@@ -1,27 +1,104 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
# text = st.text_area("Prefix", value="DM: You enter the room.")
|
| 4 |
-
# batch = st.number_input("Variants", value=5)
|
| 5 |
-
# st.markdown(f"{text} {batch}")
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
|
|
|
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
| 20 |
|
| 21 |
-
|
| 22 |
-
col1.image(image, use_column_width=True)
|
| 23 |
-
predictions = pipeline(image)
|
| 24 |
|
| 25 |
-
|
| 26 |
-
for p in predictions:
|
| 27 |
-
col2.subheader(f"{ p['label'] }: { round(p['score'] * 100, 1)}%")
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import transformers
|
| 3 |
+
import time
|
| 4 |
+
from huggingface_hub import snapshot_download
|
| 5 |
+
import streamlit as st
|
| 6 |
+
import copy
|
| 7 |
+
from transformers import AutoConfig, GPTJForCausalLM
|
| 8 |
+
from transformers.models.gptj.modeling_gptj import GPTJBlock
|
| 9 |
+
from tqdm import trange
|
| 10 |
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
+
@st.cache(allow_output_mutation=True)
def load_model():
    """Assemble the quantized GPT-J model from per-block checkpoint files.

    The checkpoint snapshot is downloaded, a GPT-J model is instantiated with
    zero transformer blocks, and the weights in ``blocks/base.pt`` are loaded
    into it. Each transformer block is then created as a copy of one
    dynamically quantized reference block, filled from its own
    ``blocks/block{i}.pt`` file and appended to the model.

    Returns:
        A ``(tokenizer, model)`` tuple: the "EleutherAI/gpt-j-6B" tokenizer
        and the assembled ``GPTJForCausalLM`` instance.
    """
    # Single-pass loop over a disabled tqdm range, kept as in the original.
    for _ in trange(1, disable=True):
        checkpoint_dir = snapshot_download("OpenDungeon/gpt-j-8bit-ffbgem", revision="separate")

    config = AutoConfig.from_pretrained("EleutherAI/gpt-j-6B")
    qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.backends.quantized.engine = 'fbgemm'

    # Remember the configured depth, then zero it so the model below is built
    # without any transformer blocks; they are appended one at a time later.
    n_layer = config.n_layer
    config.n_layer = 0

    model = GPTJForCausalLM(config)
    model.load_state_dict(torch.load(f"{checkpoint_dir}/blocks/base.pt"))

    # One quantized block serves as the template for every layer.
    ref_block = torch.quantization.quantize_dynamic(
        GPTJBlock(config),
        {torch.nn.Linear: qconfig},
        dtype=torch.qint8,
        inplace=True,
    )

    blocks = model.transformer.h
    for layer_index in trange(n_layer):
        block = copy.deepcopy(ref_block)
        block.load_state_dict(torch.load(f"{checkpoint_dir}/blocks/block{layer_index}.pt"))
        blocks.append(block)

    # Restore the layer count on the config now that the blocks exist.
    config.n_layer = len(blocks)
    del ref_block

    tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    return tokenizer, model
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def PrintContinuation(prompt, local_model, single_hook=None, batch=1, limit_tokens = 50):
    """Sample `batch` continuations of `prompt`, one token per step.

    Uses the module-level `tokenizer` to encode the prompt and decode the
    sampled tokens.

    Args:
        prompt: Text that every continuation starts from.
        local_model: Model whose ``forward`` returns an object exposing
            ``logits`` and ``past_key_values``.
        single_hook: Optional callable invoked after every step with the
            decoded token of the first continuation.
        batch: Number of continuations sampled in parallel.
        limit_tokens: Step index at which generation stops; steps
            ``0..limit_tokens`` inclusive are executed.

    Returns:
        ``(output, batch_time)``: a list of `batch` generated strings (the
        prompt itself is not included) and the average number of seconds per
        timed generation step.
    """
    # The first steps are excluded from timing. With fewer than 5 steps there
    # is no room for that warm-up window, so time from the very first step;
    # previously `start_time` was never assigned in that case and the timing
    # computation raised UnboundLocalError.
    warmup_steps = 5 if limit_tokens >= 5 else 0

    # Cache returned by the previous forward pass; feeding it back lets each
    # later step pass only the newly sampled token.
    past_key_values = None
    input_dict = tokenizer([prompt] * batch, return_tensors='pt', padding=False)
    output = [""] * batch
    batch_time = 0
    start_time = None

    with torch.inference_mode():
        # The loop normally ends through the `break` at step `limit_tokens`.
        for i in range(limit_tokens + 20):
            if i == warmup_steps:
                start_time = time.perf_counter()

            outputs = local_model.forward(**input_dict, use_cache=True, past_key_values=past_key_values)
            last_logits = outputs.logits[:, -1]

            # Add 10 to the logits of each row's 10 highest-scoring tokens,
            # so sampling strongly favours them.
            for j in range(batch):
                last_logits[j, last_logits[j].topk(k=10).indices] += 10

            past_key_values = outputs.past_key_values
            token_ix = torch.multinomial(last_logits.softmax(-1), 1)
            output = [stream + tokenizer.decode(ix) for stream, ix in zip(output, token_ix)]

            if single_hook is not None:
                single_hook(tokenizer.decode(token_ix[0]))
            if i == limit_tokens:
                # Steps warmup_steps..i inclusive were timed.
                batch_time = (time.perf_counter() - start_time) / (i - warmup_steps + 1)
                break

            # Next step only needs the freshly sampled tokens.
            input_dict = dict(input_ids=token_ix)
    return output, batch_time
|
| 70 |
+
|
| 71 |
+
import sys
|
| 72 |
+
|
| 73 |
+
def Sureprint(text):
    """Emit a "DDBG:"-tagged debug line on both stdout and stderr.

    Each write is flushed immediately.
    """
    message = f"\nDDBG: {text}\n"
    for stream in (sys.stdout, sys.stderr):
        print(message, file=stream, flush=True)
|
| 77 |
+
|
| 78 |
+
# Load the model (cached by Streamlit across reruns) and read the user inputs.
Sureprint("ready to load")
tokenizer, model = load_model()
Sureprint("loaded")
text = st.text_area("Prefix", value="DM: You enter the room.")
Sureprint(f"text acquired '{text}'")
# min_value=1: a value of 0 or below would produce an empty prompt list
# (`[prompt] * batch`) in the generation step.
batch = st.number_input("Variants", min_value=1, value=5)
|
| 84 |
|
| 85 |
+
# Placeholder element that is redrawn while tokens arrive and later
# overwritten with the final page.
t = st.empty()
# Text of the first continuation accumulated so far.
firstline = ""


def PrintSome(text):
    """Append `text` to the running preview and redraw the placeholder."""
    global firstline
    firstline = firstline + text
    t.markdown(f"{firstline}...")
|
| 92 |
|
| 93 |
+
# Generate the continuations, streaming the first one through PrintSome.
Sureprint("before inference")
choices, batch_time = PrintContinuation(text, model, PrintSome, batch, 50)
Sureprint("after inference")

# One markdown section per continuation, followed by the timing summary.
sections = [
    f"#### choice №{number} \n{choices[number - 1]} \n______ \n"
    for number in range(1, batch + 1)
]
sections.append(f"Seconds per batch: {batch_time}, Batch: {batch}")
final_page = "".join(sections)

# Replace the streamed preview with the complete result page.
t.markdown(final_page)

Sureprint("all done")
|
|
|
|
|
|