# kb-agent/mcp_server/model.py
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
from .config import MODEL_NAME, DEVICE
class LocalModel:
    """Local T5 seq2seq wrapper: loads the model once and serves text generation."""

    def __init__(self):
        # Load tokenizer and weights once at construction; MODEL_NAME and
        # DEVICE come from the package config.
        self.tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
        self.model.to(DEVICE)
        # Inference-only usage: eval() disables dropout, which otherwise
        # stays active and adds noise to every generation.
        self.model.eval()

    def generate(self, prompt: str, max_tokens: int = 128, temperature: float = 0.5) -> str:
        """Generate a completion for ``prompt`` with nucleus sampling.

        Args:
            prompt: Input text fed to the T5 encoder.
            max_tokens: Upper bound on the number of newly generated tokens.
            temperature: Sampling temperature (> 0; lower is more deterministic).

        Returns:
            The decoded generation with special tokens stripped.
        """
        # Keep the attention mask from the tokenizer and hand it to
        # generate(); dropping it forces transformers to infer the mask
        # (emits a warning, and is wrong once padding is involved).
        encoded = self.tokenizer(prompt, return_tensors="pt").to(DEVICE)
        # Pure inference — skip autograd bookkeeping entirely.
        with torch.no_grad():
            outputs = self.model.generate(
                encoded.input_ids,
                attention_mask=encoded.attention_mask,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)