Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from sklearn.preprocessing import normalize | |
| from umap import UMAP | |
| import os | |
| import json | |
# Model configuration
MODEL_ID = "openai-community/gpt2"
# Prefer GPU when available; all model calls below run on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# float32 weights; .eval() disables dropout for deterministic inference.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.float32).to(device).eval()

# Get vocabulary info; fall back to len(tokenizer) when vocab_size is unset.
vocab_size = tokenizer.vocab_size if tokenizer.vocab_size is not None else len(tokenizer)
# Raw vocabulary strings, indexed by token id (cleaned for display via nice_tok).
vocab_tokens = tokenizer.convert_ids_to_tokens(list(range(vocab_size)))

# Cache for embeddings and UMAP (computed once at startup)
embeddings_cache = None  # numpy array [vocab, dim] of input-embedding rows
umap_projections_cache = None  # numpy array [vocab, 2] of 2D UMAP coordinates

# Where the precomputed 2D layout is persisted between cold starts (overridable via env).
UMAP_LAYOUT_PATH = os.environ.get("UMAP_LAYOUT_PATH", os.path.join("assets", "umap_gpt2_cosine.npy"))

# UMAP hyperparameters; fixed random_state keeps the layout reproducible.
umap_params = {
    "n_neighbors": 75,
    "min_dist": 0.15,
    "metric": "cosine",
    "random_state": 0,
    "n_components": 2,
}
def initialize_embeddings():
    """Populate the embedding and 2D-layout caches once at startup.

    Reuses a layout previously saved at UMAP_LAYOUT_PATH when it exists and
    matches the vocabulary size; otherwise fits UMAP on the L2-normalized
    embedding rows and persists the result for future cold starts.

    Returns:
        Tuple of (embeddings [V, d], umap projections [V, 2]) as numpy arrays.
    """
    global embeddings_cache, umap_projections_cache

    # Embedding table as a float32 numpy matrix [V, d].
    weight = model.get_input_embeddings().weight
    embeddings_cache = weight.detach().cpu().numpy().astype(np.float32)

    # Cosine distance pairs naturally with unit-length rows.
    unit_rows = normalize(embeddings_cache, norm="l2", axis=1)

    # Fast path: load a precomputed layout if one is on disk and compatible.
    try:
        if os.path.isfile(UMAP_LAYOUT_PATH):
            umap_projections_cache = np.load(UMAP_LAYOUT_PATH)
            shape_ok = (
                umap_projections_cache.shape[0] == unit_rows.shape[0]
                and umap_projections_cache.shape[1] == 2
            )
            if not shape_ok:
                raise ValueError("Precomputed UMAP layout shape mismatch; recomputing.")
            return embeddings_cache, umap_projections_cache
    except Exception as err:
        # Any load/validation failure falls through to a fresh fit below.
        print(f"Warning: failed to load precomputed UMAP layout: {err}. Recomputing...")

    # Slow path: fit UMAP directly on the normalized embeddings (no PCA step).
    reducer = UMAP(**umap_params)
    umap_projections_cache = reducer.fit_transform(unit_rows).astype(np.float32)  # [V, 2]

    # Best-effort persistence so future cold starts can take the fast path.
    try:
        os.makedirs(os.path.dirname(UMAP_LAYOUT_PATH), exist_ok=True)
        np.save(UMAP_LAYOUT_PATH, umap_projections_cache)
    except Exception as err:
        print(f"Warning: failed to save UMAP layout to {UMAP_LAYOUT_PATH}: {err}")

    return embeddings_cache, umap_projections_cache


# Compute the caches once at import time.
initialize_embeddings()
def nice_tok(tok: str) -> str:
    """Replace BPE/SentencePiece whitespace markers with a plain space for display."""
    # One-pass translation of 'Ġ' (GPT-2 BPE) and '▁' (SentencePiece) to ' '.
    return tok.translate(str.maketrans({"Ġ": " ", "▁": " "}))
def get_next_token_probs(text: str | None):
    """Return next-token logits [vocab] for the given text context.

    When *text* is empty or None, computes an unconditional distribution by
    feeding a single BOS token (falling back to EOS if BOS is undefined).

    Raises:
        ValueError: if the tokenizer defines neither a BOS nor an EOS token.
    """
    if text:
        enc = tokenizer(text, return_tensors="pt").to(device)
    else:
        # Unconditional: use BOS or EOS token. Check `is None` explicitly —
        # `bos or eos` would wrongly skip a valid BOS id of 0 (falsy int).
        bos_id = tokenizer.bos_token_id
        if bos_id is None:
            bos_id = tokenizer.eos_token_id
        if bos_id is None:
            raise ValueError("Tokenizer has neither BOS nor EOS.")
        enc = {"input_ids": torch.tensor([[bos_id]], device=device)}
    # Inference only: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        out = model(**enc, return_dict=True)
    # Logits of the final position = distribution over the next token.
    return out.logits[0, -1, :]
def _distributions(logits, use_logprobs: bool):
    """Return (probs, logprobs-or-None) numpy vectors for a logits tensor."""
    probs = torch.softmax(logits, dim=-1).cpu().numpy()
    log_probs = torch.log_softmax(logits, dim=-1).cpu().numpy() if use_logprobs else None
    return probs, log_probs


def predict_comprehensive(
    text: str,
    text2: str = "",
    top_k: int = 0,
    include_embeddings: bool = False,
    include_layout: bool = False,
    include_unconditional: bool = False,
    use_logprobs: bool = True
):
    """
    Comprehensive prediction endpoint that returns token probabilities,
    embeddings, and UMAP projections for visualization.

    Args:
        text: Primary context string.
        text2: Optional secondary context for comparison.
        top_k: If > 0, include the top-k next tokens for the primary context.
        include_embeddings: Include a small sample of the embedding matrix.
        include_layout: Include the cached 2D UMAP layout.
        include_unconditional: Include the BOS-conditioned (unconditional) distribution.
        use_logprobs: Also return log-probabilities alongside probabilities.

    Returns:
        JSON-serializable dict with vocab info and the requested payloads.
    """
    result = {
        "model": MODEL_ID,
        "vocab_size": vocab_size,
        "device": device
    }

    # Primary context distribution
    probs, log_probs = _distributions(get_next_token_probs(text), use_logprobs)

    # Always include vocabulary tokens (cleaned and raw)
    result["vocab"] = {
        "tokens": [nice_tok(t) for t in vocab_tokens],
        "raw_tokens": vocab_tokens,
        "size": vocab_size
    }

    result["probs"] = probs.tolist()
    if log_probs is not None:
        result["logprobs"] = log_probs.tolist()

    # Top-k tokens if requested. Gradio sliders can deliver floats, and
    # torch.topk requires an integer k, so coerce before use.
    k = int(top_k) if top_k else 0
    if k > 0:
        vals, idxs = torch.topk(torch.from_numpy(probs), k=min(k, len(probs)))
        idxs = idxs.tolist()
        vals = vals.tolist()
        result["topk"] = {
            "ids": idxs,
            "tokens": [nice_tok(vocab_tokens[i]) for i in idxs],
            "probs": vals
        }
        if log_probs is not None:
            result["topk"]["logprobs"] = [float(log_probs[i]) for i in idxs]

    # Second text context if provided
    if text2:
        probs2, log_probs2 = _distributions(get_next_token_probs(text2), use_logprobs)
        result["probs2"] = probs2.tolist()
        if log_probs2 is not None:
            result["logprobs2"] = log_probs2.tolist()

    # Unconditional (BOS-only) distribution
    if include_unconditional:
        probs_uncond, log_probs_uncond = _distributions(
            get_next_token_probs(None), use_logprobs
        )
        result["unconditional_probs"] = probs_uncond.tolist()
        if log_probs_uncond is not None:
            result["unconditional_logprobs"] = log_probs_uncond.tolist()

    # Embeddings sample if requested (full matrix is far too large for JSON)
    if include_embeddings:
        result["embeddings_sample"] = {
            "shape": list(embeddings_cache.shape),
            "first_10_tokens_sample": embeddings_cache[:10, :10].tolist() if embeddings_cache is not None else None
        }

    # Cached UMAP layout if requested
    if include_layout and umap_projections_cache is not None:
        result["layout"] = {
            "method": "umap",
            "umap_params": umap_params,
            "projections": umap_projections_cache.tolist(),
        }

    return result
def get_embeddings_endpoint(token_ids: str = ""):
    """
    Get embeddings for specific token IDs.

    Args:
        token_ids: Comma-separated token ids, e.g. "0,1,2". Ids outside
            [0, vocab_size) are dropped. Empty or unparsable input falls
            back to the first 100 tokens.

    Returns:
        Dict with the resolved ids, display tokens, raw embedding rows,
        embedding dimensionality, and total vocab size.
    """
    default_ids = list(range(min(100, vocab_size)))  # fallback: first 100 tokens
    if token_ids:
        try:
            ids = [int(x.strip()) for x in token_ids.split(",")]
        except ValueError:
            # Catch only the parse failure from int(); don't mask other bugs
            # the way a bare `except:` would.
            ids = default_ids
        else:
            # Keep only ids that index into the embedding table.
            ids = [i for i in ids if 0 <= i < vocab_size]
    else:
        ids = default_ids

    embeddings_subset = embeddings_cache[ids]
    tokens_subset = [nice_tok(vocab_tokens[i]) for i in ids]
    return {
        "token_ids": ids,
        "tokens": tokens_subset,
        "embeddings": embeddings_subset.tolist(),
        "embedding_dim": embeddings_cache.shape[1],
        "total_vocab_size": vocab_size
    }
# Create Gradio interface: three tabs, each doubling as a named API endpoint.
with gr.Blocks(title="Token Probability Visualization API") as demo:
    gr.Markdown("""
    # Token Probability Visualization API
    This API provides token probabilities, embeddings, and a precomputed UMAP layout for visualization.
    Model: `openai-community/gpt2`
    """)

    # Tab 1: one-shot endpoint returning probabilities plus optional extras.
    with gr.Tab("Comprehensive API"):
        gr.Markdown("### Get token probabilities with optional embeddings and 2D layout")
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label="Primary Context",
                    value="You are an expert in medieval history.",
                    lines=3
                )
                text2_input = gr.Textbox(
                    label="Secondary Context (optional)",
                    placeholder="Enter second text for comparison",
                    lines=3
                )
                with gr.Row():
                    top_k_slider = gr.Slider(0, 200, step=1, value=20, label="Top-K tokens")
                    include_embeddings = gr.Checkbox(False, label="Include Embeddings Sample")
                    include_layout = gr.Checkbox(False, label="Include 2D Layout (UMAP)")
                    include_unconditional = gr.Checkbox(True, label="Include Unconditional")
                    use_logprobs = gr.Checkbox(True, label="Include Log Probabilities")
                predict_btn = gr.Button("Get Predictions", variant="primary")
            with gr.Column():
                output_json = gr.JSON(label="API Response")
        # api_name="predict" exposes this handler at /predict for API clients.
        predict_btn.click(
            fn=predict_comprehensive,
            inputs=[text_input, text2_input, top_k_slider, include_embeddings,
                    include_layout, include_unconditional, use_logprobs],
            outputs=output_json,
            api_name="predict"
        )

    # Tab 2: raw embedding rows for a comma-separated list of token ids.
    with gr.Tab("Embeddings API"):
        gr.Markdown("### Get embeddings for specific tokens")
        with gr.Row():
            with gr.Column():
                token_ids_input = gr.Textbox(
                    label="Token IDs (comma-separated)",
                    placeholder="e.g., 0,1,2,3,4 or leave empty for first 100",
                    value="0,1,2,3,4,5,6,7,8,9"
                )
                get_embeddings_btn = gr.Button("Get Embeddings", variant="primary")
            with gr.Column():
                embeddings_output = gr.JSON(label="Embeddings Response")
        get_embeddings_btn.click(
            fn=get_embeddings_endpoint,
            inputs=token_ids_input,
            outputs=embeddings_output,
            api_name="embeddings"
        )

    # Tab 3: the full-vocabulary 2D layout, intended to be fetched once.
    with gr.Tab("Layout API"):
        gr.Markdown("### Get 2D token layout (UMAP) once and cache client-side")
        with gr.Row():
            get_layout_btn = gr.Button("Get UMAP Layout", variant="primary")
            layout_output = gr.JSON(label="UMAP Layout Response")

        def get_layout_endpoint():
            """Return the cached UMAP layout plus display tokens and parameters."""
            return {
                "method": "umap",
                "umap_params": umap_params,
                "tokens": [nice_tok(t) for t in vocab_tokens],
                "projections": umap_projections_cache.tolist() if umap_projections_cache is not None else None,
            }

        get_layout_btn.click(
            fn=get_layout_endpoint,
            inputs=None,
            outputs=layout_output,
            api_name="layout",
        )

    gr.Markdown("""
    ## API Usage
    ### Endpoints:
    - `/predict`: Main endpoint for token probabilities with optional embeddings and UMAP
    - `/embeddings`: Get embeddings for specific token IDs
    - `/layout`: Get 2D UMAP projections (fetch once, reuse client-side)
    ### Response includes:
    - Token probabilities (conditional and unconditional)
    - Log probabilities
    - UMAP projections for all vocabulary tokens (if requested or via `/layout`)
    - Token embeddings (sample or specific)
    - Vocabulary mappings
    ### Use this data to:
    - Visualize token probability landscapes
    - Compare conditional vs unconditional distributions
    - Create embedding visualizations
    - Build interactive token explorers
    """)
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (the conventional HF Spaces port).
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)