Spaces:
Runtime error
Runtime error
| from flask import Flask, render_template, redirect, request, jsonify, make_response | |
| import datetime | |
| import torch | |
| import transformers | |
# Prefer the GPU, but fall back to the CPU so the app can still start on a
# machine without CUDA (moving tensors to an unavailable 'cuda' device raises).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Checkpoint name passed to transformers' from_pretrained for both the
# tokenizer and the encoder.
MODEL_NAME = 'liujch1998/vera'
class Interactive:
    """Score a single natural-language statement with the Vera encoder.

    The scoring head is not a separately trained module. It reuses rows of
    the encoder's shared embedding matrix:

    - row 32099 is the linear weight,
    - element 0 of row 32098 is the bias,
    - element 0 of row 32097 is the calibration temperature ``t``.
    """

    def __init__(self):
        """Load the tokenizer and encoder for MODEL_NAME and build the scoring head."""
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
        self.model = transformers.T5EncoderModel.from_pretrained(
            MODEL_NAME,
            low_cpu_mem_usage=True,
            device_map='auto',
            torch_dtype='auto',
            offload_folder='offload',
        )
        self.model.D = self.model.shared.embedding_dim  # embedding width D
        # Device that input ids are sent to in run(); stored on the instance
        # so run() does not depend on the module-level global.
        self.device = device
        embedding = self.model.shared.weight
        self.linear = torch.nn.Linear(self.model.D, 1, dtype=self.model.dtype)
        self.linear.weight = torch.nn.Parameter(embedding[32099, :].unsqueeze(0))  # (1, D)
        self.linear.bias = torch.nn.Parameter(embedding[32098, 0].unsqueeze(0))  # (1)
        # Move the head only AFTER its parameters are replaced. Moving it
        # first only relocates the discarded random init and leaves the real
        # weights on whatever device the embedding was dispatched to.
        self.linear.to(device)
        self.model.eval()
        self.t = embedding[32097, 0].item()  # calibration temperature

    def run(self, statement):
        """Score one statement.

        Args:
            statement: text to score; tokenized with truncation to at most
                128 tokens.

        Returns:
            A dict with:
            - 'timestamp': local time formatted as '%Y%m%d-%H%M%S'
            - 'statement': the input text, echoed back
            - 'logit': raw head output
            - 'logit_calibrated': logit / t
            - 'score': sigmoid(logit)
            - 'score_calibrated': sigmoid(logit / t)
            All numeric values are Python floats.
        """
        encoded = self.tokenizer(
            [statement],
            return_tensors='pt',
            padding='longest',
            truncation='longest_first',
            max_length=128,
        )
        input_ids = encoded.input_ids.to(self.device)
        # The head's parameters require grad, so it must also run under
        # no_grad; otherwise an autograd graph is built on every request.
        with torch.no_grad():
            output = self.model(input_ids)
            # Hidden state at the last sequence position: (B=1, L, D) -> (D)
            hidden = output.last_hidden_state[0, -1, :]
            # Put the activation on the head's device (only D elements are copied).
            hidden = hidden.to(self.linear.weight.device)
            logit = self.linear(hidden).squeeze(-1)  # ()
            logit_calibrated = logit / self.t
            score = logit.sigmoid()
            score_calibrated = logit_calibrated.sigmoid()
        return {
            'timestamp': datetime.datetime.now().strftime('%Y%m%d-%H%M%S'),
            'statement': statement,
            'logit': logit.item(),
            'logit_calibrated': logit_calibrated.item(),
            'score': score.item(),
            'score_calibrated': score_calibrated.item(),
        }
app = Flask(__name__)
# The model is loaded once, when this module is imported, and the single
# instance is reused by the request handler.
interactive = Interactive()
# NOTE(review): no route decorator was visible in the source, so this handler
# was never reachable. Serving it at '/' for GET and POST is an assumption;
# confirm against the client.
@app.route('/', methods=['GET', 'POST'])
def main():
    """Score the ``statement`` field of the request's JSON body.

    Returns:
        - On success, the dict from ``interactive.run`` as JSON.
        - 400 with ``{'success': False, ...}`` when the body is not a JSON
          object holding a non-empty string ``statement``.
        - 500 with ``{'success': False, ...}`` when scoring raises.
    """
    app.logger.debug('Request: %s', request)
    # silent=True returns None for a missing or malformed JSON body instead
    # of raising, so every bad-input case lands on the same 400 below.
    data = request.get_json(silent=True)
    statement = data.get('statement') if isinstance(data, dict) else None
    # A missing statement previously reached run(None) and surfaced as a 500.
    if not isinstance(statement, str) or not statement.strip():
        return jsonify({
            'success': False,
            'message': 'Please provide a statement.',
        }), 400
    try:
        result = interactive.run(statement)
    except Exception:
        # Keep the traceback in the server log; the client gets a generic message.
        app.logger.exception('Failed to score statement')
        return jsonify({
            'success': False,
            'message': 'Internal error.',
        }), 500
    return jsonify(result)
if __name__ == "__main__":
    # Serve over HTTPS using the host's Let's Encrypt certificate pair.
    # Port 8372 spells "VERA" on a phone keypad.
    certificate_chain = '/etc/letsencrypt/live/qa.cs.washington.edu/fullchain.pem'
    private_key = '/etc/letsencrypt/live/qa.cs.washington.edu/privkey.pem'
    app.run(
        host="0.0.0.0",
        port=8372,
        threaded=True,
        ssl_context=(certificate_chain, private_key),
    )