"""
This module contains all tasks related to transformer model
Refactored to use Hugging Face Flan-T5 (no fastT5 dependency)
@Author: Karthick T. Sharma
@Modified: LinhGPT
"""
import os

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


class Model:
    """Generalized T5/Flan-T5 model for text generation."""

    def __init__(self, model_name: str = "google/flan-t5-base"):
        """
        Load model and tokenizer into memory.

        Args:
            model_name (str): Name or path of the Hugging Face model.
        """
        # Silence the tokenizers parallelism warning before any tokenizer is created.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        print(f"🔹 Loading model: {model_name} ...")
        self.__tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.__model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        print("✅ Model and tokenizer loaded successfully.\n")

    def tokenize_corpus(self, text: str, max_length: int):
        """Tokenize model input text and return input ids and attention mask."""
        encoded = self.__tokenizer.encode_plus(
            text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
            padding="max_length",
        )
        return encoded["input_ids"], encoded["attention_mask"]

    def __extract_dict(self, input_dict):
        """Flatten a dict of key-value pairs into a single prompt string."""
        return " ".join(f"{k}: {v}" for k, v in input_dict.items())

    def inference(
        self,
        num_beams: int = 4,
        no_repeat_ngram_size: int = 2,
        model_max_length: int = 128,
        num_return_sequences: int = 1,
        token_max_length: int = 256,
        **kwargs,
    ):
        """
        Generate model output text.

        Keyword arguments are flattened into a single "key: value" prompt,
        tokenized, and passed to beam-search generation.

        Returns:
            str | list[str]: A single decoded string when
            num_return_sequences == 1, otherwise a list of decoded strings.
        """
        text = self.__extract_dict(kwargs)
        input_ids, attention_mask = self.tokenize_corpus(text, token_max_length)

        outputs = self.__model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            num_beams=num_beams,
            num_return_sequences=num_return_sequences,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_length=model_max_length,
            early_stopping=True,
        )

        decoded = [
            self.__tokenizer.decode(
                output, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            for output in outputs
        ]
        return decoded[0] if num_return_sequences == 1 else decoded
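

# Illustrative usage sketch (not part of the original module): it assumes the
# default "google/flan-t5-base" checkpoint can be downloaded or is cached
# locally, and the keyword names below ("context", "answer") are arbitrary
# examples of how **kwargs are flattened into the prompt.
if __name__ == "__main__":
    model = Model()
    generated = model.inference(
        num_beams=4,
        num_return_sequences=1,
        context="The Eiffel Tower is located in Paris and was completed in 1889.",
        answer="1889",
    )
    print(generated)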