import torch
import tiktoken

from model import GPT, GPTConfig


class EndpointHandler:
    """Inference endpoint wrapping a GPT checkpoint plus a GPT-2 BPE tokenizer.

    Instances are callable: they accept a raw prompt string or a dict payload
    and return generated token ids (and decoded text when the input was text).
    """

    def __init__(self, path=""):
        """Build the model, load weights from ``{path}/ckpt.pt``, and set up the tokenizer.

        Args:
            path: Directory containing the ``ckpt.pt`` checkpoint file.
        """
        print("Loading GPT + Qiskit model...")

        # Initialize model config and architecture
        self.config = GPTConfig()
        self.model = GPT(self.config)

        # Load checkpoint on CPU so the handler also works on GPU-less hosts.
        checkpoint_path = f"{path}/ckpt.pt"
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        # Extract state_dict if wrapped (training scripts commonly save
        # {"model": state_dict, "optimizer": ..., ...}).
        if isinstance(checkpoint, dict) and "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint

        # Strip the torch.compile wrapper prefix ('_orig_mod.') from keys so
        # the weights match the un-compiled module's parameter names.
        prefix = '_orig_mod.'
        cleaned_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): val
            for key, val in state_dict.items()
        }

        # Load non-strict so key mismatches are reported instead of raising.
        missing, unexpected = self.model.load_state_dict(cleaned_state_dict, strict=False)
        if missing:
            print("Warning: missing keys in state_dict:", missing)
        if unexpected:
            print("Warning: unexpected keys in state_dict:", unexpected)

        # Inference mode (disables dropout, etc.).
        self.model.eval()

        # Tokenizer for plain-text prompts.
        self.tokenizer = tiktoken.get_encoding("gpt2")
        print("Model loaded and ready.")

    def __call__(self, data):
        """
        Accept either:
          - A raw prompt string (data is str)
          - A dict: {"inputs": "prompt text"}
          - A dict: {"inputs": {"input_ids": [[...]]}}

        Returns:
          {"generated_ids": [[...]], optional "generated_text": str},
          or {"error": str} on any failure.
        """
        try:
            # Explicit sentinels instead of probing locals() — `text` records
            # whether the request was a text prompt (and thus needs decoding).
            text = None
            input_ids = None

            # Determine input format
            if isinstance(data, str):
                text = data
            elif isinstance(data, dict):
                inputs = data.get("inputs")
                if isinstance(inputs, str):
                    text = inputs
                elif isinstance(inputs, dict) and "input_ids" in inputs:
                    input_ids = inputs["input_ids"]
                else:
                    return {"error": "Invalid 'inputs'; expected string or dict with 'input_ids'"}
            else:
                return {"error": "Invalid request format"}

            # If a text prompt was given, tokenize it into a batch of one sequence.
            if text is not None:
                tokens = self.tokenizer.encode(text)
                input_ids = [tokens]

            # Convert to tensor (long dtype expected by embedding lookups).
            input_tensor = torch.tensor(input_ids).long()

            # Generate; autograd is unnecessary at inference time.
            with torch.no_grad():
                output_tensor = self.model.generate(input_tensor, max_new_tokens=32)
            output_ids = output_tensor.tolist()

            # Build response
            result = {"generated_ids": output_ids}
            if text is not None:
                # Decode the first sequence; omit text output if decoding fails.
                try:
                    generated_text = self.tokenizer.decode(output_ids[0])
                except Exception:
                    generated_text = None
                if generated_text is not None:
                    result["generated_text"] = generated_text
            return result

        except Exception as e:
            # Endpoint boundary: surface failures as a JSON-style error payload
            # rather than letting the serving framework see a raised exception.
            return {"error": str(e)}