File size: 3,458 Bytes
49a02a8
573941c
49a02a8
 
 
 
b3d8fec
 
 
 
 
 
 
 
 
 
573941c
b3d8fec
 
 
 
 
31b59c5
 
 
 
 
 
 
573941c
31b59c5
 
 
 
 
b3d8fec
573941c
31b59c5
573941c
 
 
b3d8fec
49a02a8
 
 
573941c
 
 
 
 
 
49a02a8
b3d8fec
573941c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3d8fec
573941c
b3d8fec
 
573941c
b3d8fec
 
 
49a02a8
573941c
 
 
 
 
 
 
 
 
 
 
 
 
49a02a8
b3d8fec
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
import tiktoken
from model import GPT, GPTConfig

class EndpointHandler:
    """Inference handler that wraps a GPT checkpoint for text generation.

    Loads a GPT model from ``{path}/ckpt.pt`` on construction and serves
    generation requests via ``__call__``. Accepts either a raw prompt string
    or a dict in the forms described on ``__call__``.
    """

    def __init__(self, path=""):
        """Build the model, load the checkpoint, and prepare the tokenizer.

        Args:
            path: Directory containing ``ckpt.pt``. Defaults to the current
                directory ("" -> "/ckpt.pt" relative path as in the original).
        """
        print("Loading GPT + Qiskit model...")

        # Initialize model config and architecture.
        # NOTE(review): GPT/GPTConfig come from the project-local `model`
        # module; their exact architecture is not visible here.
        self.config = GPTConfig()
        self.model = GPT(self.config)

        # Load checkpoint on CPU so this works on machines without a GPU.
        checkpoint_path = f"{path}/ckpt.pt"
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        # Checkpoints saved by the training loop wrap the weights under a
        # "model" key; plain state_dicts are also accepted.
        if isinstance(checkpoint, dict) and "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint

        # torch.compile() prefixes parameter names with '_orig_mod.'; strip
        # it so keys line up with the uncompiled module.
        prefix = "_orig_mod."
        cleaned_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): val
            for key, val in state_dict.items()
        }

        # Load non-strict so partial mismatches are reported, not fatal.
        missing, unexpected = self.model.load_state_dict(cleaned_state_dict, strict=False)
        if missing:
            print("Warning: missing keys in state_dict:", missing)
        if unexpected:
            print("Warning: unexpected keys in state_dict:", unexpected)

        # Inference mode; disables dropout etc.
        self.model.eval()
        # Tokenizer for plain-text prompts (GPT-2 BPE).
        self.tokenizer = tiktoken.get_encoding("gpt2")

        print("Model loaded and ready.")

    def __call__(self, data):
        """Run generation for one request.

        Accept either:
          - A raw prompt string (data is str)
          - A dict: {"inputs": "prompt text"}
          - A dict: {"inputs": {"input_ids": [[...]]}}
        Returns:
          {"generated_ids": [[...]], optional "generated_text": str}
          or {"error": "..."} on invalid input / failure.
        """
        try:
            # Explicitly initialize both possible input forms instead of
            # probing locals() later (the original `'text' in locals()`
            # check was fragile and refactor-hostile).
            text = None
            input_ids = None

            if isinstance(data, str):
                text = data
            elif isinstance(data, dict):
                inputs = data.get("inputs")
                if isinstance(inputs, str):
                    text = inputs
                elif isinstance(inputs, dict) and "input_ids" in inputs:
                    input_ids = inputs["input_ids"]
                else:
                    return {"error": "Invalid 'inputs'; expected string or dict with 'input_ids'"}
            else:
                return {"error": "Invalid request format"}

            # A text prompt is tokenized into a single-sequence batch.
            if text is not None:
                input_ids = [self.tokenizer.encode(text)]

            # Convert to a long tensor (token IDs must be integral).
            input_tensor = torch.tensor(input_ids).long()

            # Generate without tracking gradients.
            with torch.no_grad():
                output_tensor = self.model.generate(input_tensor, max_new_tokens=32)
                output_ids = output_tensor.tolist()

            result = {"generated_ids": output_ids}

            # Only decode back to text when the request was textual; raw
            # input_ids callers may use a different vocabulary.
            if text is not None:
                try:
                    # Decode the first (only) sequence in the batch.
                    result["generated_text"] = self.tokenizer.decode(output_ids[0])
                except Exception:
                    # Best-effort: token IDs are still returned even if
                    # decoding fails.
                    pass

            return result

        except Exception as e:
            # Boundary handler: surface the failure as a JSON-able error.
            return {"error": str(e)}