Spaces:
Runtime error
Runtime error
| # ------------------- LIBRARIES -------------------- # | |
| import os, logging, torch, streamlit as st | |
| from transformers import ( | |
| AutoTokenizer, AutoModelForCausalLM) | |
# --------------------- HELPER --------------------- #
# ANSI escape sequences used by ``C``; the ``None`` key is the reset code.
_ANSI_COLORS: dict = {
    "red": "\033[01;31m",
    "green": "\033[01;32m",
    "yellow": "\033[01;33m",
    "blue": "\033[01;34m",
    "magenta": "\033[01;35m",
    "cyan": "\033[01;36m",
    None: "\033[0m",
}


def C(text, color="yellow"):
    """Wrap *text* in ANSI color codes for terminal/log output.

    Args:
        text: Value to colorize; it is formatted with ``str()`` via an f-string.
        color: One of ``"red"``, ``"green"``, ``"yellow"``, ``"blue"``,
            ``"magenta"``, ``"cyan"``, or ``None`` (emit only the reset code
            as prefix). Unknown names leave the text uncolored.

    Returns:
        The string ``<color code><text><reset code>``.
    """
    # Fall back to "" so an unknown color never renders as the literal "None".
    prefix = _ANSI_COLORS.get(color, "")
    return f"{prefix}{text}{_ANSI_COLORS[None]}"
def stcache():
    """Return a decorator that caches an expensive loader across Streamlit reruns.

    Uses ``st.cache_resource`` when the installed Streamlit provides it
    (added in 1.18). Otherwise it falls back to the legacy ``st.cache``.

    Returns:
        A decorator that takes a function and returns its cached wrapper.
    """
    # Feature-detect instead of parsing st.__version__, which needed the
    # third-party ``packaging`` module.
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource
    # Legacy st.cache hashes return values to detect mutation. For large,
    # unhashable objects such as torch models that is slow or raises, so
    # opt out of it.
    return st.cache(allow_output_mutation=True, suppress_st_warning=True)
st.title("`ckip-joint/bloom-1b1-zh` demo")
# ------------------ ENVIRONMENT ------------------- #
# The root logger defaults to WARNING, which silently drops the logging.info
# progress messages in this script. basicConfig is a no-op when the root
# logger already has handlers, so it is safe on every Streamlit rerun.
logging.basicConfig(level=logging.INFO)
# NOTE(review): this is the default Hub endpoint, and huggingface_hub may read
# HF_ENDPOINT when it is imported (before this line runs). Confirm it is needed.
os.environ["HF_ENDPOINT"] = "https://huggingface.co"
# Run on the GPU when torch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
logging.info(C(f"[INFO] device = {device}"))
# ------------------ INITIALIZE -------------------- #
# Decorator from stcache() for caching expensive loaders across reruns.
stdec = stcache()
@stdec
def model_init(model_name="ckip-joint/bloom-1b1-zh"):
    """Load the tokenizer and causal-LM model for *model_name*.

    Decorated with ``stdec`` so the weights are loaded once and reused. Without
    it, every Streamlit rerun (each widget interaction re-executes the script)
    would reload the model.

    Args:
        model_name: Hugging Face Hub model id to load.

    Returns:
        ``(tokenizer, model)``. The model is in eval mode and moved to ``device``.
    """
    logging.info(C("[INFO] Model init start!"))
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Options suggested by contributors, currently disabled:
    #   torch_dtype="auto", device_map="auto"   (Ref.: Eric, Thanks!)
    #   `half` precision                         (Ref.: Chan-Jan, Thanks!)
    model = (
        AutoModelForCausalLM.from_pretrained(model_name)
        .eval()
        .to(device)
    )
    st.balloons()
    logging.info(C("[INFO] Model init success!"))
    return tokenizer, model
tokenizer, model = model_init()


def _format_result(prompt, generated):
    """Build the result markdown: prompt in blue, first generated line in bold red."""
    # Only the first line of the continuation is displayed.
    first_line = generated.split("\n", 1)[0]
    # Skip the highlight when the first line is empty; an empty "**:red[]**"
    # would not render as markup.
    highlighted = f"**:red[{first_line}]**" if first_line else ""
    return f":blue[{prompt}]{highlighted}"


try:
    # ===================== INPUT ====================== #
    prompt = st.text_input("Prompt: ")
    # =================== INFERENCE ==================== #
    if prompt:
        with st.container():
            st.markdown(f":violet[{prompt}]⋯⋯")
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            [texts_out] = model.generate(
                **inputs,
                min_new_tokens=0,
                max_new_tokens=100,
            )
        # The returned ids start with the prompt ids. Decode only the new
        # tokens rather than string-matching the decoded prompt, which fails
        # whenever decode(encode(prompt)) != prompt.
        prompt_len = inputs["input_ids"].shape[-1]
        out_gens = tokenizer.decode(
            texts_out[prompt_len:],
            skip_special_tokens=True,
        )
        st.balloons()
        st.caption("Result: ")
        st.markdown(_format_result(prompt, out_gens))
except Exception as err:
    # Top-level UI boundary: log the traceback server-side, show the message
    # on the page instead of crashing the app.
    logging.exception("Generation failed")
    st.write(str(err))
    st.snow()
| # import streamlit as st | |
| # st.markdown('Streamlit is **_really_ cool**.') | |
| # st.markdown("This text is :red[colored red], and this is **:blue[colored]** and bold.") | |
| # st.markdown(":green[$\sqrt{x^2+y^2}=1$] is a Pythagorean identity. :pencil:") | |
| # def multiline(string): | |
| # lines = string.split('\n') | |
| # return '\\\n'.join([f"**:red[{l}]**" | |
| # for l in lines]) | |
| # st.markdown(multiline("1234 \n5616")) | |
| # st.markdown("1234\\\n5616") | |
| # https://docs.streamlit.io/library/api-reference/status/st.spinner | |
| # https://stackoverflow.com/questions/32402502/how-to-change-the-time-zone-in-python-logging |