import torch
import tiktoken

from model import GPT, GPTConfig
| |
|
class EndpointHandler:
    """Inference endpoint wrapper around a nanoGPT-style checkpoint.

    Loads ``ckpt.pt`` from *path*, normalizes the state-dict key names
    (stripping any ``torch.compile`` prefix), and serves generation
    requests through ``__call__``.
    """

    def __init__(self, path=""):
        """Build the model, load weights from ``{path}/ckpt.pt``, prepare tokenizer.

        Args:
            path: Directory containing the checkpoint file (default: cwd).
        """
        print("Loading GPT + Qiskit model...")

        self.config = GPTConfig()
        self.model = GPT(self.config)

        checkpoint_path = f"{path}/ckpt.pt"
        # SECURITY NOTE(review): torch.load unpickles arbitrary objects —
        # only load checkpoints from trusted sources. weights_only=True
        # would be safer, but may reject non-tensor entries (e.g. a saved
        # config object) stored alongside the weights — confirm before enabling.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        # Training scripts commonly save {"model": state_dict, ...}; accept
        # both that layout and a bare state_dict.
        if isinstance(checkpoint, dict) and "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint

        # torch.compile() wraps the module and prefixes every parameter key
        # with "_orig_mod."; strip it so keys match the uncompiled model.
        prefix = '_orig_mod.'
        cleaned_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): val
            for key, val in state_dict.items()
        }

        # strict=False so a partially matching checkpoint still loads;
        # surface any mismatches instead of failing silently.
        missing, unexpected = self.model.load_state_dict(cleaned_state_dict, strict=False)
        if missing:
            print("Warning: missing keys in state_dict:", missing)
        if unexpected:
            print("Warning: unexpected keys in state_dict:", unexpected)

        self.model.eval()  # inference only: disable dropout etc.

        self.tokenizer = tiktoken.get_encoding("gpt2")

        print("Model loaded and ready.")

    def __call__(self, data):
        """
        Accept either:
          - A raw prompt string (data is str)
          - A dict: {"inputs": "prompt text"}
          - A dict: {"inputs": {"input_ids": [[...]]}}
        Returns:
          {"generated_ids": [[...]], optional "generated_text": str}
        """
        try:
            # Explicit sentinel instead of the fragile `'text' in locals()`
            # introspection the endpoint previously relied on: text is set
            # only when the caller supplied raw prompt text.
            text = None

            if isinstance(data, str):
                text = data
            elif isinstance(data, dict):
                inputs = data.get("inputs")
                if isinstance(inputs, str):
                    text = inputs
                elif isinstance(inputs, dict) and "input_ids" in inputs:
                    input_ids = inputs["input_ids"]
                else:
                    return {"error": "Invalid 'inputs'; expected string or dict with 'input_ids'"}
            else:
                return {"error": "Invalid request format"}

            if text is not None:
                # Encode the prompt as a single-row batch.
                input_ids = [self.tokenizer.encode(text)]

            input_tensor = torch.tensor(input_ids).long()

            with torch.no_grad():
                output_tensor = self.model.generate(input_tensor, max_new_tokens=32)
            output_ids = output_tensor.tolist()

            result = {"generated_ids": output_ids}
            if text is not None:
                # Only the text path gets a decoded string back; callers who
                # sent raw input_ids receive ids only (matches prior behavior).
                try:
                    generated_text = self.tokenizer.decode(output_ids[0])
                except Exception:
                    # Best-effort decode: keep serving ids if decode fails.
                    generated_text = None
                if generated_text is not None:
                    result["generated_text"] = generated_text

            return result

        except Exception as e:
            # Top-level endpoint boundary: surface any failure as a
            # JSON-serializable error payload rather than crashing the server.
            return {"error": str(e)}
| |
|