# entropy / app.py — Hugging Face Space by ejschwartz
# Last change: "Update app.py" (revision ce736c4, verified)
import os
import math
import spaces
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen3-4B")
#MODEL_ID = os.getenv("MODEL_ID", "bigcode/starcoder2-3b")
def load_model():
    """Load the tokenizer and causal-LM weights selected by MODEL_ID.

    Returns:
        A ``(tokenizer, model)`` pair, with the model switched to eval mode.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    # Batched encoding needs a pad token; fall back to EOS when none exists.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    lm = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
    )
    lm.eval()
    return tok, lm
# Load once at import time so every request reuses the same weights.
TOKENIZER, MODEL = load_model()
@spaces.GPU
def compute_entropy(code: str):
    """Score *code* with per-token next-token negative log-likelihood.

    Args:
        code: Source text to evaluate under the loaded causal LM.

    Returns:
        ``(summary, rows)`` where ``summary`` is a human-readable string of
        aggregate statistics and ``rows`` is a list of
        ``[token, nll_nats, prob]`` entries — or ``(message, None)`` when the
        input is empty or too short to score.
    """
    if not code or not code.strip():
        return "Please paste some source code.", None
    with torch.no_grad():
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Compare device *types* only: torch.device("cuda") != torch.device("cuda:0"),
        # so the original equality test forced a redundant MODEL.to(device) on
        # every call (risky for a device_map="auto" dispatched model).
        if next(MODEL.parameters()).device.type != device.type:
            MODEL.to(device)
        enc = TOKENIZER(code, return_tensors="pt")
        input_ids = enc["input_ids"].to(device)
        attention_mask = enc.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        # Need at least 2 tokens: the first token has no preceding context,
        # so there is no next-token prediction to score.
        if input_ids.shape[1] < 2:
            return "Input is too short to compute token-level entropy.", None
        outputs = MODEL(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        # Shift so position i predicts token i+1.
        shift_logits = logits[:, :-1, :]
        shift_labels = input_ids[:, 1:]
        log_probs = torch.log_softmax(shift_logits, dim=-1)
        # Log-probability the model assigned to each actual next token.
        true_log_probs = log_probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)
        nll = -true_log_probs  # negative log-likelihood (nats)
        nll_list = nll.squeeze(0).detach().cpu().tolist()
        label_ids = shift_labels.squeeze(0).detach().cpu().tolist()
        tokens = TOKENIZER.convert_ids_to_tokens(label_ids)
        rows = [
            [tok, float(nll_val), float(math.exp(-nll_val))]
            for tok, nll_val in zip(tokens, nll_list)
        ]
        # Hoist the total: the original recomputed sum(nll_list) three times.
        # Total "entropy" is the sum of per-token NLL; nats -> bits via ln(2).
        total_nats = sum(nll_list)
        avg_nll = total_nats / len(nll_list)
        avg_bits = avg_nll / math.log(2)
        total_bits = total_nats / math.log(2)
        summary = (
            f"Tokens evaluated: {len(nll_list)}\n"
            f"Average NLL (nats): {avg_nll:.4f}\n"
            f"Average NLL (bits): {avg_bits:.4f}\n"
            f"Total entropy (nats): {total_nats:.4f}\n"
            f"Total entropy (bits): {total_bits:.4f}"
        )
        return summary, rows
def build_app():
    """Assemble and return the Gradio Blocks UI for the entropy demo.

    Component creation order is significant in Gradio (it fixes the layout),
    so widgets are declared top-to-bottom exactly as they should render.
    """
    with gr.Blocks(title="Entropy for Source Code") as ui:
        gr.Markdown(
            f"""
            # Source Code Entropy ({MODEL_ID})
            Paste code below to compute token-level negative log-likelihood (NLL).
            The table shows each token's NLL and probability under the model.
            """
        )
        code_input = gr.Textbox(
            label="Source Code",
            lines=16,
            placeholder="Paste your source code here...",
        )
        compute_btn = gr.Button("Compute Entropy")
        summary_box = gr.Textbox(label="Summary", lines=4)
        token_table = gr.Dataframe(
            headers=["token", "nll_nats", "prob"],
            datatype=["str", "number", "number"],
            label="Token-level NLL",
        )
        # Wire the button: one input textbox in, summary + table out.
        compute_btn.click(
            fn=compute_entropy,
            inputs=[code_input],
            outputs=[summary_box, token_table],
        )
        gr.Markdown(
            """
            Notes:
            - NLL is computed for next-token prediction and excludes the first token.
            - Large inputs may take time to process depending on hardware.
            """
        )
    return ui
# Build the UI at import time so `app` exists when this module is imported
# (e.g. by a Spaces runtime); launch only when run as a script.
app = build_app()
if __name__ == "__main__":
    app.launch()