Spaces:
Build error
Build error
| """Python file to serve as the frontend""" | |
| import streamlit as st | |
| from streamlit_chat import message | |
| from langchain.chains import ConversationChain, LLMChain | |
| from langchain import PromptTemplate | |
| from langchain.llms.base import LLM | |
| from langchain.memory import ConversationBufferWindowMemory | |
| from typing import Optional, List, Mapping, Any | |
| import torch | |
| from peft import PeftModel | |
| import transformers | |
| from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig | |
| from transformers import BitsAndBytesConfig | |
# --- Model setup ------------------------------------------------------------
# Load the base LLaMA-7B tokenizer and weights, then apply the Alpaca-LoRA
# adapter on top. Everything below runs at import time.
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")

quantization_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)

model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    # load_in_8bit=True,
    # torch_dtype=torch.float16,
    device_map="auto",
    # device_map={"":"cpu"},
    # NOTE: the comma after max_memory was missing, which made the whole file
    # a SyntaxError.
    max_memory={"cpu": "15GiB"},
    quantization_config=quantization_config,
)

# Wrap the base model with the LoRA adapter weights, placed on the CPU.
model = PeftModel.from_pretrained(
    model,
    "tloen/alpaca-lora-7b",
    # torch_dtype=torch.float16,
    device_map={"": "cpu"},
)

# Input tensors are moved to this device before generation.
device = "cpu"
print("model device :", model.device, flush=True)
# model.to(device)

# Inference only: switch the model to evaluation mode.
model.eval()
def evaluate_raw_prompt(
    prompt: str,
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    max_new_tokens=256,
    **kwargs,
):
    """Generate the model's answer for an already-formatted prompt.

    Args:
        prompt: Full prompt text fed to the model as-is.
        temperature: Forwarded to ``GenerationConfig``.
        top_p: Forwarded to ``GenerationConfig``.
        top_k: Forwarded to ``GenerationConfig``.
        num_beams: Forwarded to ``GenerationConfig``.
        max_new_tokens: Upper bound on the number of generated tokens.
        **kwargs: Any further ``GenerationConfig`` options.

    Returns:
        The generated text with surrounding whitespace removed. If the model
        starts another ``### Response:`` section, only the text before it is
        returned.
    """
    response_marker = "### Response:"

    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)

    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        max_new_tokens=max_new_tokens,
        **kwargs,
    )

    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
        )

    # The returned sequence starts with the prompt tokens. Decode only what
    # was generated after them, so the answer never depends on finding a
    # marker inside the prompt (which failed with IndexError when the marker
    # was missing, and returned prompt text when the history contained it).
    prompt_length = input_ids.shape[-1]
    new_tokens = generation_output.sequences[0][prompt_length:]
    completion = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Drop anything after a repeated response marker in the generation.
    answer, _, _ = completion.partition(response_marker)
    return answer.strip()
class AlpacaLLM(LLM):
    """LangChain ``LLM`` adapter that answers prompts via ``evaluate_raw_prompt``.

    The four fields are the generation settings passed through on every call.
    """

    temperature: float
    top_p: float
    top_k: int
    num_beams: int

    # LangChain's LLM base class declares ``_llm_type`` and
    # ``_identifying_params`` as properties and reads them as attributes
    # (e.g. ``dict(self._identifying_params)``), so they must be properties
    # here too rather than plain methods.
    @property
    def _llm_type(self) -> str:
        """Return the identifier of this LLM implementation."""
        return "custom"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Run the model on ``prompt`` and return its answer.

        Raises:
            ValueError: If ``stop`` sequences are supplied; they are not
                supported by this wrapper.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        return evaluate_raw_prompt(
            prompt,
            top_p=self.top_p,
            top_k=self.top_k,
            num_beams=self.num_beams,
            temperature=self.temperature,
        )

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "top_p": self.top_p,
            "top_k": self.top_k,
            "num_beams": self.num_beams,
            "temperature": self.temperature,
        }
# Alpaca-style instruction prompt. ``{history}`` and ``{human_input}`` are the
# two placeholders filled in for every turn.
template = "\n".join(
    (
        "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.",
        "### Instruction:",
        "You are a chatbot, you should answer my last question very briefly. You are consistent and non repetitive.",
        "### Chat:",
        "{history}",
        "Human: {human_input}",
        "### Response:",
    )
)

prompt = PromptTemplate(
    template=template,
    input_variables=["history", "human_input"],
)
def load_chain():
    """Build and return the LLMChain used to answer chat messages.

    The chain combines the Alpaca LLM wrapper, the module-level ``prompt``
    and a window memory created with ``k=2``.
    """
    alpaca = AlpacaLLM(temperature=0.1, top_p=0.75, top_k=40, num_beams=4)
    window_memory = ConversationBufferWindowMemory(k=2)
    return LLMChain(llm=alpaca, prompt=prompt, memory=window_memory)
# Streamlit re-executes this script on every interaction. Building the chain
# unconditionally would create a fresh, empty conversation memory each time,
# so the chain is kept in the session state and reused across reruns.
if "chain" not in st.session_state:
    st.session_state.chain = load_chain()
chain = st.session_state.chain
| # # From here down is all the StreamLit UI. | |
| # st.set_page_config(page_title="LangChain Demo", page_icon=":robot:") | |
| # st.header("LangChain Demo") | |
| # if "generated" not in st.session_state: | |
| # st.session_state["generated"] = [] | |
| # if "past" not in st.session_state: | |
| # st.session_state["past"] = [] | |
| # def get_text(): | |
| # input_text = st.text_input("Human: ", "Hello, how are you?", key="input") | |
| # return input_text | |
| # user_input = get_text() | |
| # if user_input: | |
| # output = chain.predict(human_input=user_input) | |
| # st.session_state.past.append(user_input) | |
| # st.session_state.generated.append(output) | |
| # if st.session_state["generated"]: | |
| # for i in range(len(st.session_state["generated"]) - 1, -1, -1): | |
| # message(st.session_state["generated"][i], key=str(i)) | |
| # message(st.session_state["past"][i], is_user=True, key=str(i) + "_user") | |
st.title("ChatAlpaca")

# On the first run of a session, start the transcript with the bot's greeting.
if "history" not in st.session_state:
    st.session_state.history = [
        {
            "message": "Hey, I'm a Alpaca chatBot. Ask whatever you want!",
            "is_user": False,
        },
    ]
def generate_answer():
    """Answer the text currently in the input box and record both turns.

    Used as the ``on_change`` callback of the text input. Reads
    ``st.session_state.input_text``, asks the conversation chain for a reply
    and appends the user message and the reply to
    ``st.session_state.history``.
    """
    user_message = st.session_state.input_text

    # Nothing to answer for an empty or whitespace-only submission.
    if not user_message or not user_message.strip():
        return

    # Go through the chain so the prompt template, the chat memory and the
    # configured generation settings are actually used, instead of feeding
    # the raw text to model.generate() with its default length limits.
    message_bot = chain.predict(human_input=user_message)

    st.session_state.history.append({"message": user_message, "is_user": True})
    st.session_state.history.append({"message": message_bot, "is_user": False})
st.text_input("Response", key="input_text", on_change=generate_answer)

# Render the transcript. The chat widget is imported as ``message`` (there is
# no ``st_message`` in this file), and each bubble gets its own key so that
# two identical messages do not collide.
for index, chat in enumerate(st.session_state.history):
    message(**chat, key=f"chat_{index}")