chat-perche / app.py
abstractmachine's picture
Changed dtype in torch to float32
3ce3629
raw
history blame contribute delete
864 Bytes
import gradio as gr
from transformers import AutoTokenizer
from transformers import pipeline
import transformers
import torch
# Hub identifier of the checkpoint shared by the tokenizer and the pipeline.
model = "headmediadesign/bloom-perchay"

# Load the tokenizer and log which checkpoint it resolved to.
tokenizer = AutoTokenizer.from_pretrained(model)
print(f"tokenizer: {tokenizer.name_or_path}")

# Build the text-generation pipeline with float32 weights and automatic
# device placement, then log the checkpoint of the underlying model.
# NOTE: this assignment rebinds the module-level name `pipeline` (imported
# from transformers above) to the constructed pipeline object.
pipeline = transformers.pipeline(
    task="text-generation",
    model=model,
    torch_dtype=torch.float32,
    device_map="auto",
)
print(f"pipeline: {pipeline.model.name_or_path}")
def generate(prompt: str) -> str:
    """Sample one continuation of *prompt* from the module-level pipeline.

    Args:
        prompt: Text to continue; this function is wired to the Gradio
            text input below.

    Returns:
        The ``generated_text`` field of the first (and only requested)
        sequence. ``return_full_text=False`` asks the pipeline to return
        the continuation without echoing the prompt.
    """
    sequences = pipeline(
        prompt,
        do_sample=True,
        return_full_text=False,
        top_k=500,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        # NOTE(review): per the transformers generation docs, max_length
        # bounds prompt + generated tokens together, so very long prompts
        # leave little or no room for output. Consider max_new_tokens if a
        # fixed output budget is intended -- confirm before changing.
        max_length=1000,
    )
    return sequences[0]["generated_text"]
# Expose `generate` through a minimal web UI: one text box in, one text box out.
iface = gr.Interface(
    fn=generate,
    inputs="text",
    outputs="text",
)

# Start the Gradio server for the interface.
iface.launch()