import torch
import tiktoken
from model import GPT, GPTConfig
class EndpointHandler:
    """Inference endpoint wrapper around a GPT checkpoint.

    Loads model weights from ``<path>/ckpt.pt`` at construction time and
    exposes ``__call__`` for generation requests given either a text
    prompt or pre-tokenized ``input_ids``.
    """

    def __init__(self, path=""):
        """Build the model, load the checkpoint, and prepare the tokenizer.

        Args:
            path: Directory containing the ``ckpt.pt`` checkpoint file.
        """
        print("Loading GPT + Qiskit model...")
        # Initialize model config and architecture.
        self.config = GPTConfig()
        self.model = GPT(self.config)
        # Load checkpoint onto CPU; device placement is left to the caller.
        checkpoint_path = f"{path}/ckpt.pt"
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # Training scripts often wrap the weights under a "model" key.
        if isinstance(checkpoint, dict) and "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint
        # torch.compile prefixes parameter names with '_orig_mod.'; strip it
        # so the keys line up with the plain (un-compiled) module.
        prefix = '_orig_mod.'
        cleaned_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): val
            for key, val in state_dict.items()
        }
        # Non-strict load so partial mismatches are reported, not fatal.
        missing, unexpected = self.model.load_state_dict(cleaned_state_dict, strict=False)
        if missing:
            print("Warning: missing keys in state_dict:", missing)
        if unexpected:
            print("Warning: unexpected keys in state_dict:", unexpected)
        # Inference mode (disables dropout etc.).
        self.model.eval()
        # GPT-2 BPE tokenizer for plain-text prompts.
        self.tokenizer = tiktoken.get_encoding("gpt2")
        print("Model loaded and ready.")

    def __call__(self, data):
        """
        Accept either:
          - A raw prompt string (data is str)
          - A dict: {"inputs": "prompt text"}
          - A dict: {"inputs": {"input_ids": [[...]]}}
        Returns:
          {"generated_ids": [[...]], optional "generated_text": str}
        """
        try:
            # Sentinel tracks whether the request was a text prompt.
            # (Replaces the fragile `'text' in locals()` check.)
            text = None
            if isinstance(data, str):
                text = data
            elif isinstance(data, dict):
                inputs = data.get("inputs")
                if isinstance(inputs, str):
                    text = inputs
                elif isinstance(inputs, dict) and "input_ids" in inputs:
                    input_ids = inputs["input_ids"]
                else:
                    return {"error": "Invalid 'inputs'; expected string or dict with 'input_ids'"}
            else:
                return {"error": "Invalid request format"}
            if text is not None:
                # Encode the prompt into a single-sequence batch.
                input_ids = [self.tokenizer.encode(text)]
            # Convert to a long tensor directly (no intermediate cast).
            input_tensor = torch.tensor(input_ids, dtype=torch.long)
            # Generation needs no gradient tracking.
            with torch.no_grad():
                output_tensor = self.model.generate(input_tensor, max_new_tokens=32)
            output_ids = output_tensor.tolist()
            result = {"generated_ids": output_ids}
            if text is not None:
                # Decode the first sequence back to text; a decode failure is
                # non-fatal — the token ids are still returned.
                try:
                    result["generated_text"] = self.tokenizer.decode(output_ids[0])
                except Exception:
                    pass
            return result
        except Exception as e:
            # Endpoint boundary: surface any failure as a JSON-able error.
            return {"error": str(e)}