1142 / src /models /flan.py
longdiyao's picture
Update src/models/flan.py
a2d7478 verified
raw
history blame contribute delete
871 Bytes
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
class FlanT5:
    """Thin text-to-text wrapper around a seq2seq checkpoint.

    The instance holds a tokenizer and a model loaded from the same
    checkpoint id; ``generate`` turns a prompt string into a decoded
    output string.
    """

    def __init__(self, model_id="google/flan-t5-base"):
        """Load the tokenizer and model for *model_id*.

        Args:
            model_id: Checkpoint identifier passed to both
                ``from_pretrained`` calls. Defaults to the checkpoint this
                wrapper was originally hard-wired to.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # device_map="auto" delegates device placement to the library;
        # .eval() switches the module to inference mode (e.g. no dropout).
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            device_map="auto",
        ).eval()

    def generate(self, prompt, temperature=0.7, *, max_new_tokens=64, top_p=0.9):
        """Generate a completion for *prompt* and return it as text.

        Args:
            prompt: Input text handed to the tokenizer.
            temperature: Sampling temperature. A value ``<= 0`` is not a
                valid sampling temperature, so it selects deterministic
                (non-sampling) decoding instead of being forwarded.
            max_new_tokens: Upper bound on the number of generated tokens.
            top_p: Nucleus-sampling threshold; only used when sampling.

        Returns:
            The first generated sequence, decoded with special tokens
            removed and tokenization spaces cleaned up.
        """
        # Inputs must live on the same device as the model.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "no_repeat_ngram_size": 2,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if temperature is not None and temperature <= 0:
            # Deterministic request: skip sampling and leave out the
            # sampling-only knobs so they are not passed needlessly.
            generation_kwargs["do_sample"] = False
        else:
            generation_kwargs.update(
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
            )

        outputs = self.model.generate(**inputs, **generation_kwargs)
        return self.tokenizer.decode(
            outputs[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )