# neon-8-qubits / handler.py
import torch
import torch.nn.functional as F
from transformers import GPT2Config, GPT2Tokenizer, PreTrainedModel
from transformers.models.auto.configuration_auto import CONFIG_MAPPING


# Custom Configuration
class CustomGPTConfig(GPT2Config):
    model_type = "custom_gpt"

    def __init__(self, vocab_size=50304, n_layer=24, n_head=16, hidden_size=1024, block_size=1024, **kwargs):
        super().__init__(
            vocab_size=vocab_size,
            n_positions=block_size,
            n_ctx=block_size,
            n_embd=hidden_size,
            n_layer=n_layer,
            n_head=n_head,
            **kwargs,
        )
        self.block_size = block_size  # Ensure block_size is properly set on the config


# Register the custom configuration so AutoConfig can resolve "custom_gpt"
CONFIG_MAPPING.register("custom_gpt", CustomGPTConfig)
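
# With the mapping registered, AutoConfig resolves the custom model type from a
# saved config. A minimal usage sketch (assumes a config.json whose
# "model_type" is "custom_gpt" sits in the target directory):
#
#     from transformers import AutoConfig
#     cfg = AutoConfig.from_pretrained("./")  # returns a CustomGPTConfig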


# Wrapper for GPT to make it compatible with Hugging Face
class HuggingFaceGPT(PreTrainedModel):
    config_class = CustomGPTConfig

    def __init__(self, config):
        super().__init__(config)
        from nova_model import GPT  # Replace with your actual model import
        self.transformer = GPT(config)

    def forward(self, input_ids, **kwargs):
        targets = kwargs.get("labels", None)
        logits, loss = self.transformer(input_ids, targets=targets)
        return {"logits": logits, "loss": loss}
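
# The wrapper assumes nova_model.GPT exposes roughly this interface (a
# hypothetical sketch of the expected contract, not the actual nova_model code):
#
#     class GPT(torch.nn.Module):
#         def forward(self, idx, targets=None):
#             ...  # returns (logits, loss); loss is None when targets is None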


class EndpointHandler:
    def __init__(self, model_dir, device="cuda"):
        print(f"Initializing model from directory: {model_dir}")
        # Fall back to CPU when no GPU is available
        if device == "cuda" and not torch.cuda.is_available():
            device = "cpu"
        # Load custom configuration and model weights
        self.config = CustomGPTConfig.from_pretrained(model_dir)
        self.model = HuggingFaceGPT(self.config)
        state_dict = torch.load(f"{model_dir}/pytorch_model.bin", map_location=torch.device(device))
        self.model.load_state_dict(state_dict)
        self.model.to(device)
        self.model.eval()
        print("Model initialized successfully.")

        # Load tokenizer (the model reuses the stock GPT-2 vocabulary)
        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        self.device = device
        print("Tokenizer loaded successfully.")

    def __call__(self, inputs):
        print("Processing inputs...")
        # Extract the prompt and sampling parameters
        prompt = inputs.get("inputs", "")
        parameters = inputs.get("parameters", {})
        max_length = parameters.get("max_length", 32)
        num_return_sequences = parameters.get("num_return_sequences", 4)
        temperature = parameters.get("temperature", 1.0)
        top_k = parameters.get("top_k", 50)

        if not prompt:
            print("Error: Input prompt is missing.")
            return [{"error": "Input prompt is missing"}]

        print(f"Prompt: {prompt}")
        print(f"Parameters: {parameters}")

        # Encode the prompt and replicate it once per requested sequence
        tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        tokens = tokens.repeat(num_return_sequences, 1)
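        # Shape sketch: encode yields (1, T) for a T-token prompt; repeat
        # expands it to (num_return_sequences, T), so with the defaults the
        # batch holds 4 identical copies of the prompt.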

        # Seeded RNG so sampling is reproducible across calls
        sample_rng = torch.Generator(device=self.device)
        sample_rng.manual_seed(42)

        # Autoregressive generation: append one sampled token per step
        generated_tokens = tokens
        while generated_tokens.size(1) < max_length:
            with torch.no_grad():
                # Forward pass; keep only the logits at the last position
                output = self.model(input_ids=generated_tokens)
                logits = output["logits"][:, -1, :]
                # Temperature-scaled softmax gives a distribution over the vocabulary
                probs = F.softmax(logits / temperature, dim=-1)
                # Top-k sampling: keep the k most likely tokens, sample among them,
                # then map the sampled slot back to a vocabulary id
                topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1)
                next_token = torch.multinomial(topk_probs, 1, generator=sample_rng)
                selected_token = torch.gather(topk_indices, -1, next_token)
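                # Worked illustration (hypothetical numbers): with top_k=3 and
                # probs [0.5, 0.2, 0.1, 0.1, 0.1], topk keeps [0.5, 0.2, 0.1]
                # at indices [0, 1, 2]; multinomial renormalizes those weights
                # internally and samples a slot, which gather translates back
                # into a vocabulary id.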
                # Append the sampled token to every sequence in the batch
                generated_tokens = torch.cat((generated_tokens, selected_token), dim=1)
            # Debug log for generation progress
            print(f"Generated tokens so far: {generated_tokens.size(1)}/{max_length}")

        # Decode each sequence, truncated to max_length tokens
        results = []
        for i in range(num_return_sequences):
            tokens_list = generated_tokens[i, :max_length].tolist()
            decoded_text = self.tokenizer.decode(tokens_list, skip_special_tokens=True)
            results.append({"generated_text": decoded_text})

        print("Generation completed.")
        return results


if __name__ == "__main__":
    # Example usage
    model_directory = "./"
    handler = EndpointHandler(model_directory)
    prompt_text = "Hello, I'm a language model,"
    inputs = {
        "inputs": prompt_text,
        "parameters": {"max_length": 32, "num_return_sequences": 4, "temperature": 0.7, "top_k": 50},
    }
    print("Starting inference...")
    outputs = handler(inputs)
    for idx, result in enumerate(outputs):
        print(f"Sample {idx}: {result['generated_text']}")