|
|
import spaces |
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
|
|
|
# Lazily-initialized singletons shared by initialize_tokenizer() and
# generate_text_gpu(); both start as None and are populated on first use.
_model = None


_tokenizer = None


# Hugging Face Hub model id used for both the tokenizer and the model.
_model_name = "microsoft/DialoGPT-small"
|
|
|
|
|
def initialize_tokenizer():
    """Load the shared tokenizer on first use and return the cached instance.

    Idempotent: subsequent calls return the already-loaded tokenizer.
    """
    global _tokenizer
    if _tokenizer is not None:
        return _tokenizer
    print("[MinimalService] Loading tokenizer...")
    tok = AutoTokenizer.from_pretrained(_model_name)
    if tok.pad_token is None:
        # DialoGPT ships without a pad token; reuse EOS so padding works.
        tok.pad_token = tok.eos_token
    _tokenizer = tok
    print("[MinimalService] Tokenizer loaded successfully.")
    return _tokenizer
|
|
|
|
|
@spaces.GPU
def generate_text_gpu(prompt: str, max_tokens: int = 50):
    """GPU function for text generation.

    Lazily loads the tokenizer and model on first call, then samples up to
    ``max_tokens`` new tokens continuing ``prompt``.

    Args:
        prompt: Input text to continue.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The decoded text (the prompt followed by the generated continuation),
        with special tokens stripped.
    """
    global _model, _tokenizer

    print("[MinimalService] GPU function called")

    if _tokenizer is None:
        initialize_tokenizer()

    if _model is None:
        print("[MinimalService] Loading model...")
        _model = AutoModelForCausalLM.from_pretrained(
            _model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        print("[MinimalService] Model loaded.")

    # Tokenize via __call__ (not encode()) so we also get an attention mask.
    # Because pad_token is aliased to eos_token in initialize_tokenizer(),
    # generate() cannot reliably infer the mask from the ids alone and emits
    # "The attention mask ... was not set" warnings; pass it explicitly.
    encoded = _tokenizer(prompt, return_tensors="pt")
    device = next(_model.parameters()).device
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        outputs = _model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            temperature=0.7,
            do_sample=True,
            pad_token_id=_tokenizer.eos_token_id
        )

    response = _tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
|
|
|
|
|
class MinimalService:
    """Thin facade over the module-level GPU text-generation function."""

    def __init__(self):
        # Announce startup, then warm the tokenizer eagerly so the first
        # generate() call only pays for model loading.
        print("[MinimalService] Service initialized")
        initialize_tokenizer()

    def generate(self, prompt: str):
        """Generate a text continuation for *prompt* via the GPU function."""
        result = generate_text_gpu(prompt)
        return result
|
|
|
|
|
|
|
|
# Import-time side effect: construct the singleton service, which eagerly
# loads the tokenizer via MinimalService.__init__.
service = MinimalService()








# Sanity log confirming the @spaces.GPU-wrapped function exists at import time.
print(f"[MinimalService] GPU function available: {generate_text_gpu.__name__}")