| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Model id hoisted to one constant so tokenizer and model always agree.
MODEL_ID = "Salesforce/xgen-7b-8k-base"

# XGen ships a custom tokenizer, hence trust_remote_code; bf16 halves the
# memory footprint of the 7B checkpoint versus fp32.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
def gentext(user_input: str = "The world is") -> str:
    """Continue *user_input* with the XGen-7B base model.

    Args:
        user_input: Prompt text to complete.

    Returns:
        The decoded generation (prompt included), as a plain string so the
        Gradio ``outputs="text"`` component can render it directly.
    """
    inputs = tokenizer(user_input, return_tensors="pt")
    # max_length caps prompt + completion combined at 128 tokens.
    sample = model.generate(**inputs, max_length=128)
    # BUG FIX: previously returned {"output": ...}; a "text" output expects a
    # string, so the textbox would have shown the dict's repr.
    return tokenizer.decode(sample[0])
# Build the single-textbox demo UI and start the local Gradio server.
demo = gr.Interface(
    fn=gentext,
    inputs="text",
    outputs="text",
    title="Testing out salesforce XGen 7B",
)
demo.launch()