Spaces:
Runtime error
Runtime error
# Tutorial followed for the chat UI: https://gradio.app/creating_a_chatbot/
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import re

# GPT-2 fine-tuned end-to-end for task-oriented dialogue on the
# Schema-Guided Dialogues (SGD) dataset; loaded once at import time.
ckpt = 'armandnlp/gpt2-TOD_finetuned_SGD'
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt)
def format_resp(system_resp):
    """Make a raw system turn human-readable.

    Expands the special segment tags (<|belief|>, <|action|>,
    <|response|>) into labelled headings; all other text in the
    turn is left untouched.
    """
    tag_labels = (
        ('<|belief|>', '*Belief State: '),
        ('<|action|>', '*Actions: '),
        ('<|response|>', '*System Response: '),
    )
    for tag, label in tag_labels:
        system_resp = system_resp.replace(tag, label)
    return system_resp
def predict(input, history=[]):
    """Run one dialogue turn through the fine-tuned GPT-2 TOD model.

    Args:
        input: the new user utterance (plain text).
        history: gradio "state" — a list holding one list of token ids
            covering every previous user and system turn.  It is only
            rebound (never mutated in place), so the mutable default is
            not shared across calls.

    Returns:
        (resps, history): ``resps`` is a list of (user, system) string
        tuples for the chatbot widget; ``history`` is the updated
        token-id state including belief/action text so past turns can
        be re-displayed in full.
    """
    # Turns are delimited by the special role tags.  Raw string: a
    # non-raw '\|' is an invalid escape (SyntaxWarning on Python 3.12+),
    # and the same pattern is reused twice below.
    turn_delims = r'<\|system\|>|<\|user\|>'
    if history != []:
        # The model expects only user utterances and bare system
        # *responses* in its context — no belief or action sequences —
        # so strip those from the stored history before feeding it back.
        history_str = tokenizer.decode(history[0])
        turns = re.split(turn_delims, history_str)[1:]
        for i in range(0, len(turns) - 1, 2):
            turns[i] = '<|user|>' + turns[i]
            # Keep only the response part of each system turn.
            turns[i + 1] = '<|system|>' + turns[i + 1].split('<|response|>')[1]
        history4input = tokenizer.encode(''.join(turns), return_tensors='pt')
    else:
        history4input = torch.LongTensor(history)
    # Model input: <|context|> + history + new user input + <|endofcontext|>
    new_user_input_ids = tokenizer.encode(' <|user|> ' + input, return_tensors='pt')
    context = tokenizer.encode('<|context|>', return_tensors='pt')
    endofcontext = tokenizer.encode(' <|endofcontext|>', return_tensors='pt')
    model_input = torch.cat([context, history4input, new_user_input_ids, endofcontext], dim=-1)
    # Generate belief state + actions + response as one sequence.
    # NOTE(review): 50262 is presumably the end-of-response token id of
    # this checkpoint — confirm against the tokenizer's added vocabulary.
    out = model.generate(model_input, max_length=1024, eos_token_id=50262).tolist()[0]
    # Drop the echoed context and the <|endof...|> markers from the output.
    string_out = tokenizer.decode(out)
    system_out = (string_out.split('<|endofcontext|>')[1]
                  .replace('<|endofbelief|>', '')
                  .replace('<|endofaction|>', '')
                  .replace('<|endofresponse|>', ''))
    # New state = old state + last user input
    #           + <|system|> <|belief|> ... <|action|> ... <|response|> ...
    resp_tokenized = tokenizer.encode(' <|system|> ' + system_out, return_tensors='pt')
    history = torch.cat([torch.LongTensor(history), new_user_input_ids, resp_tokenized], dim=-1).tolist()
    # Pair up (user, system) turns for the chatbot widget; the system
    # side keeps belief + action info, formatted for readability.
    turns = tokenizer.decode(history[0])
    turns = re.split(turn_delims, turns)[1:]
    resps = [(turns[i], format_resp(turns[i + 1])) for i in range(0, len(turns) - 1, 2)]
    return resps, history
# One-click example prompts for the demo, one per SGD service/domain
# (restaurants, weather, buses, calendar, flights, events, homes,
# hotels, media, banks, movies, music, rental cars).
examples = [["I want to book a restaurant for 2 people on Saturday."],
            ["What's the weather in Cambridge today ?"],
            ["I need to find a bus to Boston."],
            ["I want to add an event to my calendar."],
            ["I would like to book a plane ticket to New York."],
            ["I want to find a concert around LA."],
            ["Hi, I'd like to find an apartment in London please."],
            ["Can you find me a hotel room near Seattle please ?"],
            ["I want to watch a film online, a comedy would be nice"],
            ["I want to transfer some money please."],
            ["I want to reserve a movie ticket for tomorrow evening"],
            ["Can you play the song Learning to Fly by Tom Petty ?"],
            ["I need to rent a small car."],
            ]
# User-facing copy for the gradio page.  Fixes two typos in the
# rendered markdown: "diaogue" -> "dialogue", "desctipions" -> "descriptions".
description = """
This is an interactive window to chat with GPT-2 fine-tuned on the Schema-Guided Dialogues dataset,
in which we find domains such as travel, weather, media, calendar, banking,
restaurant booking...
"""
article = """
### Model Outputs
This task-oriented dialogue system is trained end-to-end, following the method detailed in
[SimpleTOD](https://arxiv.org/pdf/2005.00796.pdf), where GPT-2 is trained by casting task-oriented
dialogue as a seq2seq task.
From the dialogue history, composed of the previous user and system responses, the model is trained
to output the belief state, the action decisions and the system response as a sequence. We show all
three outputs in this demo : the belief state tracks the user goal (restaurant cuisine : Indian or media
genre : comedy for ex.), the action decisions show how the system should proceed (restaurants request city
or media offer title for ex.) and the natural language response provides an output the user can interpret.
The model responses are *de-lexicalized* : database values in the training set have been replaced with their
slot names to make the learning process database agnostic. These slots are meant to later be replaced by actual
results from a database, using the belief state to issue calls.
The model is capable of dealing with multiple domains : a list of possible inputs is provided to get the
conversation going.
### Dataset
The SGD dataset ([blogpost](https://ai.googleblog.com/2019/10/introducing-schema-guided-dialogue.html) and
[article](https://arxiv.org/pdf/1909.05855.pdf)) contains multiple task domains... Here is a list of some
of the services and their descriptions from the dataset:
* **Restaurants** : *A leading provider for restaurant search and reservations*
* **Weather** : *Check the weather for any place and any date*
* **Buses** : *Find a bus to take you to the city you want*
* **Calendar** : *Calendar service to manage personal events and reservations*
* **Flights** : *Find your next flight*
* **Events** : *Get tickets for the coolest concerts and sports in your area*
* **Homes** : *A widely used service for finding apartments and scheduling visits*
* **Hotels** : *A popular service for searching and reserving rooms in hotels*
* **Media** : *A leading provider of movies for searching and watching on-demand*
* **Banks** : *Manage bank accounts and transfer money*
* **Movies** : *A go-to provider for finding movies, searching for show times and booking tickets*
* **Music** : *A popular provider of a wide range of music content for searching and listening*
* **RentalCars** : *Car rental service with extensive coverage of locations and cars*
"""
import gradio as gr

# Wire the predict function into a stateful chat interface: the "state"
# input/output pair carries the token-id history between turns.
demo = gr.Interface(
    fn=predict,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    title="Chatting with multi task-oriented GPT2",
    examples=examples,
    description=description,
    article=article,
)
demo.launch()