Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| import torch | |
| import transformers | |
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class Interactive:
    """Score a natural-language statement with the cd-pi verification model.

    The tokenizer and model are loaded once at construction time. A scalar
    logit is produced by projecting the encoder's hidden state at the final
    input position through a 1-output linear head. The head's weight and bias
    are taken from rows of the model's shared embedding matrix.
    """

    def __init__(self, model_name='liujch1998/cd-pi', temperature=2.2247):
        """Load the tokenizer and model, and build the scoring head.

        Args:
            model_name: Hugging Face Hub id of the checkpoint to load.
            temperature: divisor applied to the raw logit before the
                "calibrated" sigmoid in ``run``.

        Raises:
            KeyError: if the ``HF_TOKEN_DOWNLOAD`` environment variable is unset.
        """
        token = os.environ['HF_TOKEN_DOWNLOAD']
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, use_auth_token=token)
        self.model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=token).to(device)
        self.linear = torch.nn.Linear(self.model.shared.embedding_dim, 1).to(device)
        # The head reuses embedding row 32099 as its weight and element 0 of
        # row 32098 as its bias. These look like reserved/sentinel token ids
        # repurposed during training -- NOTE(review): confirm against the
        # training code.
        self.linear.weight = torch.nn.Parameter(self.model.shared.weight[32099, :].unsqueeze(0))  # (1, D)
        self.linear.bias = torch.nn.Parameter(self.model.shared.weight[32098, 0].unsqueeze(0))  # (1,)
        self.model.eval()
        self.t = temperature

    def run(self, statement):
        """Score a single statement.

        Args:
            statement: the text to verify.

        Returns:
            A dict of Python floats with these keys:
            ``logit``: the raw head output.
            ``logit_calibrated``: the logit divided by the temperature.
            ``score``: sigmoid of ``logit``.
            ``score_calibrated``: sigmoid of ``logit_calibrated``.
        """
        input_ids = self.tokenizer.batch_encode_plus(
            [statement], return_tensors='pt', padding='longest'
        ).input_ids.to(device)
        with torch.no_grad():
            # Calling an encoder-decoder model directly requires decoder inputs,
            # and its output has no ``last_hidden_state``. Only the encoder
            # states are needed here, so run the encoder alone.
            encoder_output = self.model.get_encoder()(input_ids=input_ids)
            last_hidden_state = encoder_output.last_hidden_state  # (B=1, L, D)
            hidden = last_hidden_state[0, -1, :]  # (D,)
            # Inside no_grad so the head does not build an autograd graph.
            logit = self.linear(hidden).squeeze(-1)  # ()
        logit_calibrated = logit / self.t
        score = logit.sigmoid()
        score_calibrated = logit_calibrated.sigmoid()
        return {
            'logit': logit.item(),
            'logit_calibrated': logit_calibrated.item(),
            'score': score.item(),
            'score_calibrated': score_calibrated.item(),
        }
# Single module-level scorer: the tokenizer and model are loaded once here and
# then reused by every predict() call.
interactive = Interactive()
def predict(statement, model):
    """Return Gradio ``Label`` confidences for *statement*.

    ``model`` receives the value of the read-only model textbox. It is accepted
    so the signature matches the Interface's two inputs, but it is not used.
    The result maps 'True' to the calibrated score and 'False' to its complement.
    """
    p_true = interactive.run(statement)['score_calibrated']
    return {'True': p_true, 'False': 1 - p_true}
# Statements offered in the input dropdown.
examples = [
    'If A sits next to B and B sits next to C, then A must sit next to C.',
    'If A sits next to B and B sits next to C, then A might not sit next to C.',
]

input_statement = gr.Dropdown(choices=examples, label='Statement:')
# Display-only: shows which checkpoint is in use; predict() ignores its value.
input_model = gr.Textbox(label='Commonsense statement verification model:', value='liujch1998/cd-pi', interactive=False)
# ``gr.outputs.*`` is the legacy namespace, deprecated in Gradio 3 and removed
# in Gradio 4. ``gr.Label`` is the supported component in both versions.
output = gr.Label(num_top_classes=2)
description = '''This is a demo for a commonsense statement verification model. Under development.'''

# Keep a handle on the Interface so it can be reused or inspected.
demo = gr.Interface(
    fn=predict,
    inputs=[input_statement, input_model],
    outputs=output,
    title="cd-pi Demo",
    description=description,
)
demo.launch()