Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| import torch | |
| import transformers | |
| import huggingface_hub | |
| import datetime | |
| import json | |
| import shutil | |
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Hugging Face access tokens, read from the environment. A missing variable
# raises KeyError here, at import time.
HF_TOKEN_DOWNLOAD = os.environ['HF_TOKEN_DOWNLOAD']
HF_TOKEN_UPLOAD = os.environ['HF_TOKEN_UPLOAD']

MODEL_NAME = 'liujch1998/cd-pi'
DATASET_REPO_URL = "https://huggingface.co/datasets/liujch1998/cd-pi-dataset"
DATA_DIR = 'data'
DATA_PATH = os.path.join(DATA_DIR, 'data.jsonl')

# Best-effort removal of any existing local copy before cloning.
# ignore_errors=True keeps the "don't care if it fails or is absent" behaviour
# without a bare `except:` that would also swallow KeyboardInterrupt/SystemExit.
shutil.rmtree(DATA_DIR, ignore_errors=True)

# Clone the dataset repository into DATA_DIR, then pull.
repo = huggingface_hub.Repository(
    local_dir=DATA_DIR,
    clone_from=DATASET_REPO_URL,
    token=HF_TOKEN_UPLOAD,
    repo_type='dataset',
)
repo.git_pull()
class Interactive:
    """Scores a single statement with the MODEL_NAME encoder and a scalar head.

    The head is a one-output linear layer whose weight, bias and calibration
    divisor are all read out of fixed positions in the model's shared
    embedding matrix (rows 32099, 32098 and 32097 respectively).
    """

    def __init__(self):
        """Load the tokenizer and encoder, then build the scoring head."""
        auth = HF_TOKEN_DOWNLOAD
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=auth)
        self.model = transformers.T5EncoderModel.from_pretrained(
            MODEL_NAME,
            use_auth_token=auth,
            low_cpu_mem_usage=True,
            device_map='auto',
            torch_dtype='auto',
        )
        embeddings = self.model.shared
        self.linear = torch.nn.Linear(embeddings.embedding_dim, 1, dtype=self.model.dtype).to(device)
        # Overwrite the freshly initialised head with values stored in the embedding table.
        self.linear.weight = torch.nn.Parameter(embeddings.weight[32099, :].unsqueeze(0))  # (1, D)
        self.linear.bias = torch.nn.Parameter(embeddings.weight[32098, 0].unsqueeze(0))  # (1)
        self.model.eval()
        # Scalar the raw logit is divided by to obtain the calibrated logit.
        self.t = embeddings.weight[32097, 0].item()

    def run(self, statement):
        """Score one statement.

        Args:
            statement: text handed to the tokenizer as a batch of one.

        Returns:
            dict with float values under 'logit', 'logit_calibrated',
            'score' (sigmoid of the logit) and 'score_calibrated'
            (sigmoid of the logit divided by ``self.t``).
        """
        encoded = self.tokenizer.batch_encode_plus([statement], return_tensors='pt', padding='longest')
        input_ids = encoded.input_ids.to(device)
        with torch.no_grad():
            states = self.model(input_ids).last_hidden_state.to(device)  # (B=1, L, D)
            # The hidden state at the final position of the only sequence feeds the head.
            logit = self.linear(states[0, -1, :]).squeeze(-1)  # ()
            logit_calibrated = logit / self.t
            result = {
                'logit': logit.item(),
                'logit_calibrated': logit_calibrated.item(),
                'score': logit.sigmoid().item(),
                'score_calibrated': logit_calibrated.sigmoid().item(),
            }
        return result
# Single model wrapper, built once at import time and reused by every call to predict().
interactive = Interactive()


def predict(statement):
    """Score a statement, log the query to the dataset repo, and return label confidences.

    Args:
        statement: the statement text to verify.

    Returns:
        dict mapping 'True' to the calibrated score and 'False' to one minus it.

    Side effects:
        Appends one JSON line to DATA_PATH, pushes the dataset repository with
        ``repo.push_to_hub()`` and prints the value that call returns.
    """
    result = interactive.run(statement)
    row = {
        'timestamp': datetime.datetime.now().strftime('%Y%m%d-%H%M%S'),
        'statement': statement,
        'logit': result['logit'],
        'logit_calibrated': result['logit_calibrated'],
        'score': result['score'],
        'score_calibrated': result['score_calibrated'],
    }
    # ensure_ascii=False writes non-ASCII characters verbatim, so the file must
    # be opened as UTF-8 explicitly instead of relying on the locale default.
    with open(DATA_PATH, 'a', encoding='utf-8') as f:
        json.dump(row, f, ensure_ascii=False)
        f.write('\n')
    # Push only after the file has been closed so the new line is on disk.
    commit_url = repo.push_to_hub()
    print(commit_url)
    return {
        'True': result['score_calibrated'],
        'False': 1 - result['score_calibrated'],
    }
# Example statements grouped by the dataset they were taken from. Each entry
# is (source name, statements); `examples` below flattens the groups in order.
_EXAMPLE_GROUPS = (
    ('openbookqa', (
        'If a person walks in the opposite direction of a compass arrow they are walking south.',
        'If a person walks in the opposite direction of a compass arrow they are walking north.',
    )),
    ('arc_easy', (
        'A pond is different from a lake because ponds are smaller and shallower.',
        'A pond is different from a lake because ponds have moving water.',
    )),
    ('arc_hard', (
        'Hunting strategies are more likely to be learned rather than inherited.',
        'A spotted coat is more likely to be learned rather than inherited.',
    )),
    ('ai2_science_elementary', (
        'Photosynthesis uses carbon from the air to make food for plants.',
        'Respiration uses carbon from the air to make food for plants.',
    )),
    ('ai2_science_middle', (
        'The barometer measures atmospheric pressure.',
        'The thermometer measures atmospheric pressure.',
    )),
    ('commonsenseqa', (
        'People aim to complete a job at work.',
        'People aim to kill animals at work.',
    )),
    ('qasc', (
        'Climate is generally described in terms of local weather conditions.',
        'Climate is generally described in terms of forests.',
    )),
    ('physical_iqa', (
        'ice box will turn into a cooler if you add water to it.',
        'ice box will turn into a cooler if you add soda to it.',
    )),
    ('social_iqa', (
        'Kendall opened their mouth to speak and what came out shocked everyone. Kendall is a very aggressive and talkative person.',
        'Kendall opened their mouth to speak and what came out shocked everyone. Kendall is a very quiet person.',
    )),
    ('winogrande_xl', (
        'Sarah was a much better surgeon than Maria so Maria always got the easier cases.',
        'Sarah was a much better surgeon than Maria so Sarah always got the easier cases.',
    )),
    ('com2sense_paired', (
        'If you want a quick snack, getting one banana would be a good choice generally.',
        'If you want a snack, getting twenty bananas would be a good choice generally.',
    )),
    ('sciq', (
        'Each specific polypeptide has a unique linear sequence of amino acids.',
        'Each specific polypeptide has a unique linear sequence of fatty acids.',
    )),
    ('quarel', (
        'Tommy glided across the marble floor with ease, but slipped and fell on the wet floor because wet floor has more resistance.',
        'Tommy glided across the marble floor with ease, but slipped and fell on the wet floor because marble floor has more resistance.',
    )),
    ('quartz', (
        'If less waters falls on an area of land it will cause less plants to grow in that area.',
        'If less waters falls on an area of land it will cause more plants to grow in that area.',
    )),
    ('cycic_mc', (
        'In U.S. spring, Rob visits the financial district every day. In U.S. winter, Rob visits the park every day. Rob will go to the park on January 20.',
        'In U.S. spring, Rob visits the financial district every day. In U.S. winter, Rob visits the park every day. Rob will go to the financial district on January 20.',
    )),
    ('comve_a', (
        'Summer in North America is great for swimming, boating, and fishing.',
        'Summer in North America is great for skiing, snowshoeing, and making a snowman.',
    )),
    ('csqa2', (
        'Gas is always capable of turning into liquid under high pressure.',
        'Cotton candy is sometimes made out of cotton.',
    )),
    ('symkd_anno', (
        'James visits a famous landmark. As a result, James learns about the world.',
        'Cliff and Andrew enter the castle. But before, Cliff needed to have been a student at the school.',
    )),
    ('gengen_anno', (
        'Generally, bar patrons are capable of taking care of their own drinks.',
        'Generally, ocean currents have little influence over storm intensity.',
    )),
    # Disabled pair, kept for reference:
    # 'If A sits next to B and B sits next to C, then A must sit next to C.',
    # 'If A sits next to B and B sits next to C, then A might not sit next to C.',
)

# Flat list of statements, in the same order as the groups above.
examples = [statement for _source, group in _EXAMPLE_GROUPS for statement in group]
# Input: a dropdown restricted to the example statements.
input_statement = gr.Dropdown(choices=examples, label='Statement:')
# NOTE: this read-only textbox is constructed but is not passed to gr.Interface below.
input_model = gr.Textbox(label='Commonsense statement verification model:', value=MODEL_NAME, interactive=False)
# Use the top-level gr.Label component, matching gr.Dropdown / gr.Textbox above;
# the legacy gr.outputs namespace is deprecated and absent from newer Gradio releases.
output = gr.Label(num_top_classes=2)
description = '''This is a demo for a commonsense statement verification model. Under development.'''

# Build the demo around predict() and start serving it.
gr.Interface(
    fn=predict,
    inputs=[input_statement],
    outputs=output,
    title="cd-pi Demo",
    description=description,
).launch()