Spaces:
Build error
Build error
| from transformers import AutoModelForCausalLM | |
| import torch | |
| import gradio as gr | |
| import re | |
# Load the pretrained addition model once at import time; generate_solution
# uses this instance as its default `model`.
model = AutoModelForCausalLM.from_pretrained(
    "Manuel2011/addition_model",
)
class NumberTokenizer:
    """Whitespace-delimited tokenizer over a fixed 13-symbol vocabulary.

    The vocabulary is ``+``, ``=``, the pad symbol ``-1`` and the digits
    ``0``-``9``; each symbol's id is its position in that list.
    """

    def __init__(self, numbers_qty=10):
        symbols = ['+', '=', '-1', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
        # Stored for callers; not read anywhere inside this class.
        self.numbers_qty = numbers_qty
        self.pad_token = '-1'
        # symbol -> id and id -> symbol, both ordered by id.
        self.encoder = dict(zip(symbols, range(len(symbols))))
        self.decoder = dict(enumerate(symbols))
        self.pad_token_id = self.encoder[self.pad_token]

    def decode(self, token_ids):
        """Map each id back to its symbol and join the symbols with single spaces.

        Raises KeyError for an id outside the vocabulary.
        """
        return ' '.join(map(self.decoder.__getitem__, token_ids))

    def __call__(self, text):
        """Split *text* on whitespace and return the id of every piece as a list.

        Raises KeyError for any piece outside the vocabulary.
        """
        return list(map(self.encoder.__getitem__, text.split()))


tokenizer = NumberTokenizer(numbers_qty=13)
def generate_solution(input, solution_length=6, model=model):
    """Greedily decode the model's answer to a single-digit addition exercise.

    Extracts ``"<a> + <b>"`` from *input*, re-tokenizes it in the canonical
    ``"a + b ="`` form, then runs greedy (argmax) decoding for
    *solution_length* steps, feeding each predicted token back into the model.

    Args:
        input: Text containing an exercise such as ``"3 + 4 ="``. Both
            operands must be single digits.
        solution_length: Number of tokens to generate.
        model: Causal LM whose output exposes ``.logits``; defaults to the
            module-level model.

    Returns:
        The generated token ids decoded by the module-level tokenizer (symbols
        joined by spaces), or ``'Invalid input'`` if *input* is not a string
        or contains no single-digit addition.
    """
    # The lookarounds reject multi-digit operands: without them "12 + 3"
    # would silently be answered as "2 + 3".
    match = (
        re.search(r'(?<!\d)(\d)\s*\+\s*(\d)(?!\d)', input)
        if isinstance(input, str)
        else None
    )
    if match is None:
        return 'Invalid input'
    first_number = int(match.group(1))
    second_number = int(match.group(2))

    model.eval()
    prompt = f'{first_number} + {second_number} ='
    sequence = torch.tensor(tokenizer(prompt))
    solution = []
    # Pure inference: skip building an autograd graph on every decoding step.
    with torch.no_grad():
        for _ in range(solution_length):
            output = model(sequence)
            # NOTE(review): assumes logits for a 1-D input are
            # (seq_len, vocab), so [-1] is the last position's distribution
            # — TODO confirm for this model's architecture.
            predicted = output.logits[-1].argmax()
            sequence = torch.cat((sequence, predicted.unsqueeze(0)), dim=0)
            solution.append(predicted.item())
    return tokenizer.decode(solution)
def solve(input):
    """Gradio callback: answer the exercise in *input* with a two-token result.

    Presumably two tokens because the largest single-digit sum (9 + 9 = 18)
    has two digits; confirm against the model's training format.
    """
    answer_tokens = 2
    return generate_solution(input, solution_length=answer_tokens)
# Gradio UI: one single-line text box in (the exercise), one out (the
# model's answer), with two clickable examples; launched at import time.
demo = gr.Interface(
    fn=solve,
    inputs=[
        gr.Textbox(
            label="Addition exercise",
            lines=1,
            info="The input must be of the form '1 + 2 =', with a single space between each character, and only single-digit numbers are allowed.",
        ),
    ],
    outputs=[
        gr.Textbox(label="Result", lines=1),
    ],
    title="Simple addition with a GPT-like model",
    description="Perform addition of two single-digit numbers using a GPT-like model trained on a small dataset.",
    examples=["1 + 2 =", "5 + 7 ="],
)
demo.launch()