# neon-q12-v1-250M / handler.py
# Uploaded by MarkChenX — commit "Update handler.py" (573941c, verified)
import torch
import tiktoken
from model import GPT, GPTConfig
class EndpointHandler:
    """Inference-endpoint wrapper around a GPT checkpoint.

    Loads model weights from ``{path}/ckpt.pt`` and serves generation
    requests given either a raw text prompt or pre-tokenized input IDs.
    """

    def __init__(self, path=""):
        """Build the model, load the checkpoint, and prepare the tokenizer.

        Args:
            path: Directory containing the checkpoint file ``ckpt.pt``.
        """
        print("Loading GPT + Qiskit model...")
        # Initialize model config and architecture
        self.config = GPTConfig()
        self.model = GPT(self.config)
        # Load checkpoint on CPU so the handler also works without a GPU
        checkpoint_path = f"{path}/ckpt.pt"
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # Extract the state_dict if the checkpoint wraps it under "model"
        if isinstance(checkpoint, dict) and "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint
        # Strip the '_orig_mod.' prefix torch.compile adds to parameter keys
        prefix = '_orig_mod.'
        cleaned_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): val
            for key, val in state_dict.items()
        }
        # Load non-strict so key mismatches are reported instead of raising
        missing, unexpected = self.model.load_state_dict(cleaned_state_dict, strict=False)
        if missing:
            print("Warning: missing keys in state_dict:", missing)
        if unexpected:
            print("Warning: unexpected keys in state_dict:", unexpected)
        # Inference mode only
        self.model.eval()
        # GPT-2 BPE tokenizer for text prompts
        self.tokenizer = tiktoken.get_encoding("gpt2")
        print("Model loaded and ready.")

    def __call__(self, data):
        """
        Accept either:
        - A raw prompt string (data is str)
        - A dict: {"inputs": "prompt text"}
        - A dict: {"inputs": {"input_ids": [[...]]}}
        Returns:
        {"generated_ids": [[...]], optional "generated_text": str}
        """
        try:
            # Explicit sentinel instead of the fragile `'text' in locals()` check:
            # text stays None when the caller supplied pre-tokenized input_ids.
            text = None
            if isinstance(data, str):
                text = data
            elif isinstance(data, dict):
                inputs = data.get("inputs")
                if isinstance(inputs, str):
                    text = inputs
                elif isinstance(inputs, dict) and "input_ids" in inputs:
                    input_ids = inputs["input_ids"]
                else:
                    return {"error": "Invalid 'inputs'; expected string or dict with 'input_ids'"}
            else:
                return {"error": "Invalid request format"}

            # If a text prompt was given, tokenize it into a one-sequence batch
            if text is not None:
                input_ids = [self.tokenizer.encode(text)]

            # Convert to tensor and generate
            input_tensor = torch.tensor(input_ids).long()
            with torch.no_grad():
                output_tensor = self.model.generate(input_tensor, max_new_tokens=32)
            output_ids = output_tensor.tolist()

            # Build response; decode back to text only for text-prompt requests
            result = {"generated_ids": output_ids}
            if text is not None:
                try:
                    # Decode the first sequence (best effort — omit key on failure)
                    result["generated_text"] = self.tokenizer.decode(output_ids[0])
                except Exception:
                    pass
            return result
        except Exception as e:
            # Endpoint boundary: surface any failure as an error payload
            return {"error": str(e)}