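"""Load a local T5 model, generate answers from it, and unload it again.

The model files live in the ``model_files`` directory next to this script and
are downloaded on first use via ``download_models``.
"""
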
import os
import random

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

if __package__ is None or __package__ == '':
    # Running as a plain script: import from the same directory.
    from download_models import download_models
else:
    # Running as part of a package: use a relative import.
    from .download_models import download_models

tokenizer = None
model = None

script_dir = os.path.dirname(os.path.abspath(__file__))
modelDir = os.path.join(script_dir, "model_files")


def load_models():
    global tokenizer, model

    # Fetch the model files on first use if the local copy is missing or empty.
    if not os.path.isdir(modelDir) or not os.listdir(modelDir):
        print("Model files not found. Downloading...\n")
        download_models()

    tokenizer = T5Tokenizer.from_pretrained(modelDir)
    model = T5ForConditionalGeneration.from_pretrained(modelDir, torch_dtype=torch.float16)


def unload_models():
    global tokenizer, model

    # Drop the in-memory references so the tokenizer and model can be garbage-collected.
    del tokenizer
    del model

    # Also remove the downloaded model files from disk.
    for file in os.listdir(modelDir):
        os.remove(os.path.join(modelDir, file))
    os.rmdir(modelDir)


def answer(input_text="", max_new_tokens=512, repetition_penalty=1.2, temperature=0.5, top_p=1, top_k=1, seed=-1):
    # A seed of -1 means "pick a fresh random seed" so repeated calls can differ.
    if seed == -1:
        seed = random.randint(1, 1000000)
    torch.manual_seed(seed)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Tokenize the prompt and make sure the inputs and the model share a device.
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
    model.to(device)

    outputs = model.generate(input_ids, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty,
                             do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k)

    # Strip the special tokens T5 wraps around the generated answer.
    dirty_text = tokenizer.decode(outputs[0])
    text = dirty_text.replace("<pad>", "").replace("</s>", "").strip()

    return text
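

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original module):
    # load the model, ask one illustrative question, and print the answer.
    # unload_models() is left commented out because it also deletes the
    # downloaded model files from disk.
    load_models()
    print(answer("What is the capital of France?"))
    # unload_models()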