Manuel2011 committed on
Commit
32b4fdd
·
1 Parent(s): e37ac90

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM
2
+ import torch
3
+ import gradio as gr
4
+
5
# Load the causal language model once, at import time, from
# "models/addition_model" (presumably a directory shipped next to this
# script -- confirm against the deployment layout).
model = AutoModelForCausalLM.from_pretrained("models/addition_model")
6
+
7
class NumberTokenizer:
    """Whitespace tokenizer over a fixed arithmetic vocabulary.

    The vocabulary is ``+``, ``=``, the padding symbol ``-1`` and the digits
    ``0``-``9``. A token's id is its position in that list.
    """

    def __init__(self, numbers_qty=10):
        """Build the token <-> id lookup tables.

        Args:
            numbers_qty: Stored on the instance unchanged; nothing in this
                class reads it back.
        """
        vocab = ['+', '=', '-1', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
        self.numbers_qty = numbers_qty
        self.pad_token = '-1'
        # Every vocabulary entry is already a string, so no conversion is
        # needed when building the two lookup tables.
        self.encoder = {token: idx for idx, token in enumerate(vocab)}
        self.decoder = dict(enumerate(vocab))
        self.pad_token_id = self.encoder[self.pad_token]

    def decode(self, token_ids):
        """Return the symbols for ``token_ids`` joined by single spaces.

        Raises:
            KeyError: If an id is not in the vocabulary.
        """
        return ' '.join(self.decoder[t] for t in token_ids)

    def __call__(self, text):
        """Split ``text`` on whitespace and return the list of token ids.

        Raises:
            KeyError: If a whitespace-separated piece is not in the
                vocabulary.
        """
        return [self.encoder[t] for t in text.split()]
21
+
22
# Shared tokenizer instance. The argument 13 is stored as ``numbers_qty``;
# it matches the number of vocabulary entries (presumably intentional).
tokenizer = NumberTokenizer(13)
23
+
24
def generate_solution(input, solution_length=6, model=model):
    """Greedily generate ``solution_length`` tokens that continue ``input``.

    Args:
        input: Space-separated prompt such as ``"1 + 2 ="``; it is encoded
            with the module-level ``tokenizer``.
        solution_length: Number of tokens to generate.
        model: Callable whose result exposes a ``logits`` attribute. Defaults
            to the module-level ``model`` (bound when this function is
            defined).

    Returns:
        The generated tokens decoded by ``tokenizer`` and joined with single
        spaces.

    Raises:
        KeyError: If ``input`` contains a piece the tokenizer does not know.
    """
    model.eval()
    token_ids = torch.tensor(tokenizer(input))
    solution = []
    # Inference only: disable autograd so no graph is built for the
    # repeated forward passes.
    with torch.no_grad():
        for _ in range(solution_length):
            output = model(token_ids)
            # NOTE(review): assumes the model accepts an unbatched 1-D id
            # tensor, so logits[-1] holds the scores for the last position
            # -- confirm against the model implementation.
            predicted = output.logits[-1].argmax()
            # Feed the chosen token back in for the next step.
            token_ids = torch.cat((token_ids, predicted.unsqueeze(0)), dim=0)
            solution.append(predicted.cpu().item())
    return tokenizer.decode(solution)
35
+
36
def solve(input):
    """Return the two generated tokens for the exercise text ``input``.

    Used as the callback of the Gradio interface; delegates to
    ``generate_solution`` with ``solution_length=2``.
    """
    return generate_solution(input, solution_length=2)
38
+
39
# Gradio UI: one text box in, one text box out, wired to ``solve``.
demo = gr.Interface(
    fn=solve,
    inputs=[
        gr.Textbox(
            label="Addition exercise",
            lines=1,
            info="The input must be of the form '1 + 2 =', with a single space between each character, and only single-digit numbers are allowed.",
        )
    ],
    outputs=[gr.Textbox(label="Result", lines=1)],
    title="Simple addition with a GPT-like model",
    description="Perform addition of two single-digit numbers using a GPT-like model trained on a small dataset.",
    examples=["1 + 2 =", "5 + 7 ="],
)

# Start the web app when the script is executed.
demo.launch()