| from transformers import GPTNeoForCausalLM, GPT2Tokenizer |
| import gradio as gr |
|
|
# Load the 125M-parameter GPT-Neo checkpoint and its matching GPT-2 BPE
# tokenizer once at import time; both are reused by every chat request.
model, tokenizer = (
    GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M"),
    GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-125M"),
)
|
|
# Few-shot conversation that primes the model with the locomotive persona.
# It deliberately ends with an open "person: " line so that the user's
# message (appended by chat_base) completes it and the model continues
# the dialogue from there.
prompt = "\n".join(
    [
        "This is a discussion between a person and a diesel locomotive.",
        "person: Are you a diesel locomotive?",
        "Diesel locomotive: Yes!!",
        "person: Where are you?",
        "Diesel locomotive: I live in a railway station!!",
        "person: Who are you?",
        "Diesel locomotive: I'm a machine!!",
        "person: ",
    ]
)
|
|
def my_split(s, seps):
    """Split ``s`` on every separator in ``seps``, applied one after another.

    Each separator is applied to all pieces produced by the previous ones, so
    the result is the same as splitting on any of them. Empty pieces (from
    adjacent separators or separators at the ends) are kept. With no
    separators the result is ``[s]``.
    """
    pieces = [s]
    for separator in seps:
        pieces = [part for chunk in pieces for part in chunk.split(separator)]
    return pieces
|
|
| |
def _extract_reply(generated):
    """Pull the locomotive's reply out of the text generated after the prompt.

    ``generated`` is only the model's continuation (the prompt is not
    included). It is split on ``]`` and newlines. The segment after the first
    separator is taken as the reply. The first segment is typically the tail
    of the user's own ``person:`` line, since the prompt ends there. If the
    model produced no separator, the first segment is used instead of raising
    IndexError. A ``Diesel locomotive: `` speaker tag is removed, and
    surrounding whitespace is stripped.
    """
    segments = my_split(generated, [']', '\n'])
    reply = segments[1] if len(segments) > 1 else segments[0]
    # The previous guard tested for a stale "Hassan: " tag, so the locomotive
    # tag was never stripped. Check for the tag that is actually split on.
    if "Diesel locomotive: " in reply:
        reply = reply.split("Diesel locomotive: ")[-1]
    return reply.strip()


def chat_base(input):
    """Generate the diesel locomotive's reply to one user message.

    The message is appended to the few-shot ``prompt``. The model samples a
    continuation, and the next dialogue line is returned as the reply.

    Args:
        input: The user's message. The name shadows the builtin but is kept
            for compatibility with existing callers.

    Returns:
        The reply text. It may be empty if the model produced nothing usable,
        for example when the prompt already fills ``max_length``.
    """
    p = prompt + input
    input_ids = tokenizer(p, return_tensors="pt").input_ids
    # NOTE: max_length counts the prompt tokens too, so long messages leave
    # little or no room for the reply.
    gen_tokens = model.generate(
        input_ids,
        do_sample=True,
        temperature=0.7,
        max_length=150,
    )
    # Decode only the newly generated tokens rather than slicing the decoded
    # text by len(p). Decoding is not guaranteed to reproduce the prompt
    # character-for-character, and skip_special_tokens drops <|endoftext|>.
    new_tokens = gen_tokens[0, input_ids.shape[-1]:]
    generated = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return _extract_reply(generated)
| |
| import gradio as gr |
|
|
def chat(message):
    """Answer ``message`` and record the exchange in the session history.

    Uses the legacy ``gr.get_state`` / ``gr.set_state`` session-state API.
    History is a list of ``(message, response)`` tuples.

    Returns:
        The reply text produced by ``chat_base``.
    """
    history = gr.get_state() or []
    response = chat_base(message)
    history.append((message, response))
    gr.set_state(history)
    # Removed an HTML transcript that was built here but never returned. It
    # interpolated raw user text into markup. Also removed a debug print of
    # the history.
    return response
|
|
# Single-turn text UI: one question box in, the reply from chat_base out.
# Screenshot and flagging controls are disabled. The app is launched
# immediately when this script runs.
question_box = gr.inputs.Textbox(label="Ask a Question")
iface = gr.Interface(
    fn=chat_base,
    inputs=question_box,
    outputs="text",
    allow_screenshot=False,
    allow_flagging=False,
)
iface.launch()