LLM / app.py
Tristan
Add interactive token visualization with hover tooltips showing top-5 alternatives
b779be4
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import uvicorn
import os
# FastAPI application serving the token-visualization UI and the inference endpoints.
app = FastAPI()

# Load models and tokenizer at import time so requests never wait on a download.
# NOTE(review): this blocks process startup until both models are fetched.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# NOTE(review): generator_pipe is never referenced elsewhere in this file
# (/generate calls generate_with_alternatives directly), and passing the model
# *name* makes the pipeline load a second copy of the weights — confirm whether
# this can be removed or should reuse the `model` object above.
generator_pipe = pipeline("text-generation", model=model_name, tokenizer=tokenizer)
summarizer_pipe = pipeline("summarization", model="facebook/bart-large-cnn")

# Static assets (JS/CSS for the hover-tooltip visualization) served from ./static.
app.mount("/static", StaticFiles(directory="static"), name="static")
class GenRequest(BaseModel):
    """Request body shared by the /generate and /predict_next endpoints."""

    # Prompt text (or the document to condense when mode == "summarize").
    text: str
    # Cap on generated tokens; reused as max_length for the summarizer.
    max_new_tokens: int = 150
    # Sampling flag; only forwarded to the summarization pipeline — the
    # generation path below is always greedy.
    do_sample: bool = False
    # "generate" (default) or "summarize".
    mode: str = "generate"
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve the single-page UI from templates/index.html.

    Reads the file with an explicit UTF-8 encoding: the built-in default
    encoding is platform-dependent, so without it the page could be
    mis-decoded on hosts whose locale is not UTF-8.
    """
    with open("templates/index.html", "r", encoding="utf-8") as f:
        return f.read()
@app.post("/generate")
def generate(req: GenRequest):
    """Dispatch a request to the appropriate backend.

    mode == "summarize" routes to the BART summarization pipeline; any
    other mode falls through to greedy token-by-token generation with
    per-token top-5 alternatives for the hover tooltips.
    """
    if req.mode != "summarize":
        # Default path: token-level generation with alternatives.
        return generate_with_alternatives(req)

    summary = summarizer_pipe(
        req.text,
        max_length=req.max_new_tokens,
        min_length=30,
        do_sample=req.do_sample,
    )
    return {"generated_text": summary[0]["summary_text"]}
def generate_with_alternatives(req: GenRequest):
    """Greedily generate up to req.max_new_tokens tokens, recording the
    top-5 candidate tokens (with percent probabilities) at each step.

    Returns:
        dict with "generated_text" (the concatenated chosen tokens) and
        "tokens", a list of {"token": str, "alternatives": [{"token",
        "probability"}]} entries consumed by the hover-tooltip UI.

    The loop carries raw token ids and appends the chosen id directly.
    The previous implementation decoded the running text and re-tokenized
    the whole string every step, which is O(n^2) in tokenizer work and —
    worse — lossy: a decode/encode round trip is not guaranteed to
    reproduce the chosen token sequence, so the model could end up
    conditioning on different ids than it actually picked.
    """
    input_ids = tokenizer(req.text, return_tensors="pt").input_ids
    tokens_data = []

    for _ in range(req.max_new_tokens):
        with torch.no_grad():
            # NOTE(review): full-sequence forward pass each step; a KV cache
            # (past_key_values) would make this linear — confirm before adding.
            outputs = model(input_ids)
        next_token_logits = outputs.logits[0, -1, :]
        probs = torch.softmax(next_token_logits, dim=-1)

        # Top-5 candidates; index 0 is the greedy (chosen) token.
        top_probs, top_indices = torch.topk(probs, 5)
        chosen_token_id = top_indices[0].item()

        # Stop *before* emitting end-of-sequence so its literal text
        # (e.g. "<|endoftext|>") never leaks into the generated output.
        if chosen_token_id == tokenizer.eos_token_id:
            break

        alternatives = [
            {
                "token": tokenizer.decode([token_id]),
                "probability": round(prob * 100, 2),
            }
            for prob, token_id in zip(top_probs.tolist(), top_indices.tolist())
        ]
        tokens_data.append({
            "token": tokenizer.decode([chosen_token_id]),
            "alternatives": alternatives,
        })

        # Extend the context with the chosen id — no re-tokenization.
        input_ids = torch.cat(
            [input_ids, torch.tensor([[chosen_token_id]], dtype=input_ids.dtype)],
            dim=1,
        )

    return {
        "generated_text": "".join(t["token"] for t in tokens_data),
        "tokens": tokens_data,
    }
@app.post("/predict_next")
def predict_next(req: GenRequest):
    """Return the model's top-10 next-token candidates for req.text.

    Each prediction carries the decoded token and its probability as a
    percentage rounded to two decimals.
    """
    encoded = tokenizer(req.text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**encoded).logits

    # Distribution over the vocabulary at the final position.
    distribution = torch.softmax(logits[0, -1, :], dim=-1)
    top = torch.topk(distribution, 10)

    predictions = [
        {
            "token": tokenizer.decode([token_id]),
            "probability": round(p * 100, 2),
        }
        for p, token_id in zip(top.values.tolist(), top.indices.tolist())
    ]
    return {"predictions": predictions}
if __name__ == "__main__":
    # Honor a platform-provided PORT (e.g. container hosts); default to 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))