File size: 11,822 Bytes
4b9c5bf
 
 
 
d5ba468
 
51afe30
4b9c5bf
 
 
 
 
 
 
 
 
 
 
 
 
 
d5ba468
4b9c5bf
d5ba468
51afe30
d5ba468
 
 
 
 
 
 
4b9c5bf
 
d5ba468
 
 
 
 
 
 
 
 
 
 
51afe30
 
 
 
 
 
 
 
 
 
 
d5ba468
51afe30
d5ba468
51afe30
 
 
 
 
 
 
 
d5ba468
 
4b9c5bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7e22f6
4b9c5bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7e22f6
 
d5ba468
c7e22f6
d5ba468
 
c7e22f6
4b9c5bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7e22f6
4b9c5bf
 
 
 
c7e22f6
4b9c5bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7e22f6
4b9c5bf
 
 
 
 
 
 
 
 
 
c7e22f6
 
4b9c5bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5ba468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b9c5bf
 
d5ba468
4b9c5bf
d5ba468
4b9c5bf
d5ba468
4b9c5bf
 
 
 
d5ba468
4b9c5bf
 
 
 
 
 
 
 
 
 
 
d5ba468
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
import gradio as gr
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.preprocessing import normalize
from umap import UMAP
import os
import json

# Model configuration
MODEL_ID = "openai-community/gpt2"
# Prefer GPU when available; everything below (model + inference) follows this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and tokenizer
# eval() disables dropout; float32 keeps logits deterministic across devices.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.float32).to(device).eval()

# Get vocabulary info
# vocab_size falls back to len(tokenizer) for tokenizers that don't expose the attribute.
vocab_size = tokenizer.vocab_size if tokenizer.vocab_size is not None else len(tokenizer)
vocab_tokens = tokenizer.convert_ids_to_tokens(list(range(vocab_size)))

# Cache for embeddings and UMAP (computed once at startup)
embeddings_cache = None  # [V, d] float32 token-embedding matrix, filled by initialize_embeddings()
umap_projections_cache = None  # [V, 2] float32 2D layout, filled by initialize_embeddings()
# On-disk cache for the UMAP layout; override location via the UMAP_LAYOUT_PATH env var.
UMAP_LAYOUT_PATH = os.environ.get("UMAP_LAYOUT_PATH", os.path.join("assets", "umap_gpt2_cosine.npy"))
# Fixed UMAP hyperparameters; random_state pins the layout so cached and fresh runs agree.
umap_params = {
    "n_neighbors": 75,
    "min_dist": 0.15,
    "metric": "cosine",
    "random_state": 0,
    "n_components": 2,
}

def initialize_embeddings():
    """Populate the module-level embedding and UMAP-layout caches.

    Loads a precomputed 2D layout from ``UMAP_LAYOUT_PATH`` when a valid one
    exists; otherwise fits UMAP over the full vocabulary and saves the result
    for future cold starts. Returns ``(embeddings_cache, umap_projections_cache)``.
    """
    global embeddings_cache, umap_projections_cache

    # Token-embedding matrix, detached to host memory as float32.  [V, d]
    weight = model.get_input_embeddings().weight
    embeddings_cache = weight.detach().cpu().numpy().astype(np.float32)

    # Unit-length rows: cosine-metric UMAP works best on normalized vectors.
    unit_rows = normalize(embeddings_cache, norm="l2", axis=1)

    # Fast path: reuse a layout persisted by an earlier run, if it checks out.
    try:
        if os.path.isfile(UMAP_LAYOUT_PATH):
            umap_projections_cache = np.load(UMAP_LAYOUT_PATH)
            # Sanity-check the cached layout against the current vocabulary.
            bad_rows = umap_projections_cache.shape[0] != unit_rows.shape[0]
            if bad_rows or umap_projections_cache.shape[1] != 2:
                raise ValueError("Precomputed UMAP layout shape mismatch; recomputing.")
            return embeddings_cache, umap_projections_cache
    except Exception as e:
        # Any load/validation failure falls through to a full recompute.
        print(f"Warning: failed to load precomputed UMAP layout: {e}. Recomputing...")

    # Slow path: project the whole vocabulary straight from the normalized
    # embeddings to 2D (no PCA pre-reduction).
    reducer = UMAP(**umap_params)
    umap_projections_cache = reducer.fit_transform(unit_rows).astype(np.float32)  # [V, 2]

    # Best-effort persistence so the next cold start takes the fast path.
    try:
        os.makedirs(os.path.dirname(UMAP_LAYOUT_PATH), exist_ok=True)
        np.save(UMAP_LAYOUT_PATH, umap_projections_cache)
    except Exception as e:
        print(f"Warning: failed to save UMAP layout to {UMAP_LAYOUT_PATH}: {e}")

    return embeddings_cache, umap_projections_cache

# Initialize embeddings at startup
# Warming the caches at import time keeps the first API request fast; this can
# take a while on a cold start when no precomputed UMAP layout file exists.
initialize_embeddings()

def nice_tok(tok: str) -> str:
    """Render a raw subword token as display text.

    GPT-2's ``Ġ`` and SentencePiece's ``▁`` both mark a leading space; map
    every occurrence of either back to a real space in one pass.
    """
    return tok.translate(str.maketrans({"Ġ": " ", "▁": " "}))

@torch.no_grad()
def get_next_token_probs(text: str | None):
    """Return next-token logits for *text*, or the unconditional prior.

    Args:
        text: Context string. When empty or ``None``, the model is conditioned
            on its BOS token (falling back to EOS) to approximate an
            unconditional next-token distribution.

    Returns:
        1-D tensor of raw (unnormalized) logits over the vocabulary for the
        position following the context.

    Raises:
        ValueError: if an unconditional prior is requested but the tokenizer
            defines neither a BOS nor an EOS token.
    """
    if text:
        enc = tokenizer(text, return_tensors="pt").to(device)
        out = model(**enc, return_dict=True)
    else:
        # Unconditional: condition on BOS, falling back to EOS.
        # Use explicit None checks — `bos_token_id or eos_token_id` would
        # wrongly discard a legitimate BOS whose id is 0 (falsy).
        bos_id = tokenizer.bos_token_id
        if bos_id is None:
            bos_id = tokenizer.eos_token_id
        if bos_id is None:
            raise ValueError("Tokenizer has neither BOS nor EOS.")
        input_ids = torch.tensor([[bos_id]], device=device)
        out = model(input_ids=input_ids, return_dict=True)

    # Logits at the last position of the single sequence in the batch.
    return out.logits[0, -1, :]

def _prob_arrays(logits, use_logprobs: bool):
    """Convert raw logits to ``(probs, log_probs)`` numpy arrays.

    ``log_probs`` is ``None`` when *use_logprobs* is false. Both softmax and
    log-softmax are computed directly from logits for numerical stability.
    """
    probs = torch.softmax(logits, dim=-1).cpu().numpy()
    log_probs = torch.log_softmax(logits, dim=-1).cpu().numpy() if use_logprobs else None
    return probs, log_probs

@torch.no_grad()
def predict_comprehensive(
    text: str,
    text2: str = "",
    top_k: int = 0,
    include_embeddings: bool = False,
    include_layout: bool = False,
    include_unconditional: bool = False,
    use_logprobs: bool = True
):
    """
    Comprehensive prediction endpoint that returns token probabilities,
    embeddings, and UMAP projections for visualization.

    Args:
        text: Primary context; its next-token distribution is always returned.
        text2: Optional second context for comparison (``probs2``/``logprobs2``).
        top_k: When > 0, also return the top-k tokens by probability.
        include_embeddings: Attach a small sample of the embedding matrix.
        include_layout: Attach the full 2D UMAP layout of the vocabulary.
        include_unconditional: Attach the BOS-conditioned (prior) distribution.
        use_logprobs: Also return log-probabilities alongside probabilities.

    Returns:
        JSON-serializable dict with model metadata, vocabulary, full
        probability vectors, and the optional sections listed above.
    """
    result = {
        "model": MODEL_ID,
        "vocab_size": vocab_size,
        "device": device
    }

    # Primary context distribution (always computed).
    logits = get_next_token_probs(text)
    probs, log_probs = _prob_arrays(logits, use_logprobs)

    # Always include vocabulary tokens (cleaned and raw).
    result["vocab"] = {
        "tokens": [nice_tok(t) for t in vocab_tokens],
        "raw_tokens": vocab_tokens,
        "size": vocab_size
    }

    result["probs"] = probs.tolist()
    if log_probs is not None:
        result["logprobs"] = log_probs.tolist()

    # Top-k tokens if requested. Gradio sliders may deliver floats, and
    # torch.topk requires an integer k, so coerce defensively.
    top_k = int(top_k or 0)
    if top_k > 0:
        vals, idxs = torch.topk(torch.from_numpy(probs), k=min(top_k, len(probs)))
        idxs = idxs.tolist()
        vals = vals.tolist()
        result["topk"] = {
            "ids": idxs,
            "tokens": [nice_tok(vocab_tokens[i]) for i in idxs],
            "probs": vals
        }
        if log_probs is not None:
            result["topk"]["logprobs"] = [float(log_probs[i]) for i in idxs]

    # Second text context if provided.
    if text2:
        probs2, log_probs2 = _prob_arrays(get_next_token_probs(text2), use_logprobs)
        result["probs2"] = probs2.tolist()
        if log_probs2 is not None:
            result["logprobs2"] = log_probs2.tolist()

    # Unconditional (BOS-conditioned) distribution.
    if include_unconditional:
        probs_uncond, log_probs_uncond = _prob_arrays(get_next_token_probs(None), use_logprobs)
        result["unconditional_probs"] = probs_uncond.tolist()
        if log_probs_uncond is not None:
            result["unconditional_logprobs"] = log_probs_uncond.tolist()

    # Embedding sample if requested (full matrix would be far too large).
    if include_embeddings:
        result["embeddings_sample"] = {
            "shape": list(embeddings_cache.shape),
            "first_10_tokens_sample": embeddings_cache[:10, :10].tolist() if embeddings_cache is not None else None
        }

    # Full-vocabulary UMAP layout if requested and available.
    if include_layout:
        if umap_projections_cache is not None:
            result["layout"] = {
                "method": "umap",
                "umap_params": umap_params,
                "projections": umap_projections_cache.tolist(),
            }

    return result

@torch.no_grad()
def get_embeddings_endpoint(token_ids: str = ""):
    """
    Get embeddings for specific token IDs.

    Args:
        token_ids: Comma-separated token ids (e.g. ``"0,1,2"``). Empty
            segments are tolerated. Out-of-range ids are dropped. An empty,
            unparsable, or fully-out-of-range request falls back to the first
            100 vocabulary tokens.

    Returns:
        Dict with the resolved ids, their display tokens, the embedding
        vectors, the embedding dimensionality, and the vocabulary size.
    """
    default_ids = list(range(min(100, vocab_size)))
    ids = default_ids
    if token_ids:
        try:
            # Skip empty segments so trailing commas ("1,2,") don't error out.
            parsed = [int(x) for x in token_ids.split(",") if x.strip()]
        except ValueError:
            # Non-numeric input: fall back to the default range.
            parsed = []
        parsed = [i for i in parsed if 0 <= i < vocab_size]
        if parsed:
            ids = parsed

    embeddings_subset = embeddings_cache[ids]
    tokens_subset = [nice_tok(vocab_tokens[i]) for i in ids]

    return {
        "token_ids": ids,
        "tokens": tokens_subset,
        "embeddings": embeddings_subset.tolist(),
        "embedding_dim": embeddings_cache.shape[1],
        "total_vocab_size": vocab_size
    }

# Create Gradio interface
# Components are created in display order inside the `with` context; each
# Tab groups one API surface, and api_name on each click handler exposes the
# corresponding /predict, /embeddings, and /layout REST endpoints.
with gr.Blocks(title="Token Probability Visualization API") as demo:
    gr.Markdown("""
    # Token Probability Visualization API
    
    This API provides token probabilities, embeddings, and a precomputed UMAP layout for visualization.
    Model: `openai-community/gpt2`
    """)
    
    with gr.Tab("Comprehensive API"):
        gr.Markdown("### Get token probabilities with optional embeddings and 2D layout")
        
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label="Primary Context",
                    value="You are an expert in medieval history.",
                    lines=3
                )
                text2_input = gr.Textbox(
                    label="Secondary Context (optional)",
                    placeholder="Enter second text for comparison",
                    lines=3
                )
                
                # Toggles map 1:1 onto predict_comprehensive's keyword flags.
                with gr.Row():
                    top_k_slider = gr.Slider(0, 200, step=1, value=20, label="Top-K tokens")
                    include_embeddings = gr.Checkbox(False, label="Include Embeddings Sample")
                    include_layout = gr.Checkbox(False, label="Include 2D Layout (UMAP)")
                    include_unconditional = gr.Checkbox(True, label="Include Unconditional")
                    use_logprobs = gr.Checkbox(True, label="Include Log Probabilities")
                
                predict_btn = gr.Button("Get Predictions", variant="primary")
            
            with gr.Column():
                output_json = gr.JSON(label="API Response")
        
        # Input order here must match predict_comprehensive's parameter order.
        predict_btn.click(
            fn=predict_comprehensive,
            inputs=[text_input, text2_input, top_k_slider, include_embeddings,
                   include_layout, include_unconditional, use_logprobs],
            outputs=output_json,
            api_name="predict"
        )
    
    with gr.Tab("Embeddings API"):
        gr.Markdown("### Get embeddings for specific tokens")
        
        with gr.Row():
            with gr.Column():
                token_ids_input = gr.Textbox(
                    label="Token IDs (comma-separated)",
                    placeholder="e.g., 0,1,2,3,4 or leave empty for first 100",
                    value="0,1,2,3,4,5,6,7,8,9"
                )
                get_embeddings_btn = gr.Button("Get Embeddings", variant="primary")
            
            with gr.Column():
                embeddings_output = gr.JSON(label="Embeddings Response")
        
        get_embeddings_btn.click(
            fn=get_embeddings_endpoint,
            inputs=token_ids_input,
            outputs=embeddings_output,
            api_name="embeddings"
        )
    
    with gr.Tab("Layout API"):
        gr.Markdown("### Get 2D token layout (UMAP) once and cache client-side")
        with gr.Row():
            get_layout_btn = gr.Button("Get UMAP Layout", variant="primary")
            layout_output = gr.JSON(label="UMAP Layout Response")

        def get_layout_endpoint():
            """Return the full-vocabulary 2D UMAP layout plus display tokens.

            Reads the module-level caches; `projections` is None only if the
            startup initialization did not populate the layout cache.
            """
            return {
                "method": "umap",
                "umap_params": umap_params,
                "tokens": [nice_tok(t) for t in vocab_tokens],
                "projections": umap_projections_cache.tolist() if umap_projections_cache is not None else None,
            }

        get_layout_btn.click(
            fn=get_layout_endpoint,
            inputs=None,
            outputs=layout_output,
            api_name="layout",
        )

    gr.Markdown("""
    ## API Usage

    ### Endpoints:
    - `/predict`: Main endpoint for token probabilities with optional embeddings and UMAP
    - `/embeddings`: Get embeddings for specific token IDs
    - `/layout`: Get 2D UMAP projections (fetch once, reuse client-side)
    
    ### Response includes:
    - Token probabilities (conditional and unconditional)
    - Log probabilities
    - UMAP projections for all vocabulary tokens (if requested or via `/layout`)
    - Token embeddings (sample or specific)
    - Vocabulary mappings
    
    ### Use this data to:
    - Visualize token probability landscapes
    - Compare conditional vs unconditional distributions
    - Create embedding visualizations
    - Build interactive token explorers
    """)

if __name__ == "__main__":
    # Bind to all interfaces on the standard Gradio/Spaces port so the app is
    # reachable when run inside a container.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)