import json
import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Run on the GPU when one is available; otherwise everything stays on the CPU.
use_cuda = torch.cuda.is_available()

# SECURITY: the Hugging Face access token used to be hard-coded here. A token
# committed to source control must be treated as leaked and revoked. It is now
# read from the HF_TOKEN environment variable (e.g. a Space secret); when the
# variable is unset, None is passed and no explicit token is used.
_hf_token = os.environ.get("HF_TOKEN")
_model_id = "keminglu/pivoine-7b"

# Left padding keeps the prompt adjacent to the generated continuation.
tokenizer = AutoTokenizer.from_pretrained(
    _model_id,
    use_auth_token=_hf_token,
    padding_side="left",
)
model = AutoModelForCausalLM.from_pretrained(
    _model_id,
    use_auth_token=_hf_token,
    torch_dtype=torch.float16,
)

# Inference only: disable gradient tracking and switch to eval mode.
model.requires_grad_(False)
model.eval()
if use_cuda:
    model = model.to("cuda")

# UI assets shipped next to this script. Context managers make sure the file
# handles are closed instead of being left to the garbage collector.
with open("examples.json", encoding="utf-8") as examples_file:
    examples = json.load(examples_file)
with open("description.txt", encoding="utf-8") as description_file:
    description = description_file.read()
def inference(context, instruction, num_beams: int = 4):
    """Generate the model's answer for a context/instruction pair.

    The prompt is the quoted context, a blank line, then the instruction,
    with a trailing period appended when missing. The continuation is
    produced with deterministic beam search.

    Args:
        context: Text the instruction applies to.
        instruction: Instruction appended after the quoted context.
        num_beams: Beam width for generation; coerced to ``int`` because a
            UI slider may deliver the value as a float.

    Returns:
        A tuple ``(num_generated_tokens, string_output, json_output)``:
        the count of generated non-padding tokens, the decoded generated
        text, and that text parsed as JSON (or an ``{"error": ...}`` dict
        when it is not valid JSON).

    Raises:
        gr.Error: If parsing fails with anything other than a JSON decode
            error.
    """
    num_beams = int(num_beams)

    input_str = f"\"{context}\"\n\n{instruction}"
    if not input_str.endswith("."):
        input_str += "."

    input_tokens = tokenizer(input_str, return_tensors="pt", padding=True)
    if use_cuda:
        input_tokens = {
            key: value.to("cuda") if torch.is_tensor(value) else value
            for key, value in input_tokens.items()
        }

    output = model.generate(
        input_tokens["input_ids"],
        # Pass the mask explicitly rather than leaving generate() to infer it.
        attention_mask=input_tokens.get("attention_mask"),
        num_beams=num_beams,
        do_sample=False,
        max_new_tokens=2048,
        num_return_sequences=1,
        return_dict_in_generate=True,
    )

    # The returned sequences start with the prompt; keep only what follows.
    num_input_tokens = input_tokens["input_ids"].shape[1]
    generated_tokens = output.sequences[:, num_input_tokens:]
    num_generated_tokens = int(
        (generated_tokens != tokenizer.pad_token_id).sum(dim=-1)[0]
    )

    # Decode with a one-character "A" token in front and drop that character
    # afterwards -- presumably so decoding does not alter the start of the
    # generated text (TODO confirm). The prefix is created on the same device
    # as the generated tokens; it was previously forced onto "cuda", which
    # crashed on CPU-only hosts.
    prefix_id = tokenizer("A")["input_ids"][0]
    prefix_to_add = torch.tensor(
        [[prefix_id]],
        dtype=generated_tokens.dtype,
        device=generated_tokens.device,
    )
    generated_tokens = torch.cat([prefix_to_add, generated_tokens], dim=1)
    generated_text = tokenizer.batch_decode(
        generated_tokens, skip_special_tokens=True
    )
    string_output = generated_text[0][1:].strip()

    try:
        json_output = json.loads(string_output)
    except json.JSONDecodeError:
        # Surface malformed output in the JSON panel instead of failing.
        json_output = {"error": "Unfortunately, there is a JSON decode error on your output, which is really rare in our experiment :("}
    except Exception as e:
        raise gr.Error(str(e)) from e

    return num_generated_tokens, string_output, json_output
# Web UI: two free-text inputs (context, instruction) and a beam-width slider
# feed `inference`; its three return values fill the three output widgets.
demo = gr.Interface(
    fn=inference,
    title="Instruction-following Open-world Information Extraction",
    description=description,
    inputs=[
        "text",
        "text",
        gr.Slider(minimum=1, maximum=5, value=4, step=1),
    ],
    outputs=[
        gr.Number(label="Number of Generated Tokens"),
        gr.Textbox(label="Raw String Output"),
        gr.JSON(label="Json Output"),
    ],
    examples=examples,
    examples_per_page=3,
)

# show_error=True displays raised errors in the browser UI.
demo.launch(show_error=True)