Andres37062's picture
Increase probability precision to 10 decimal places
50cf1fb verified
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import torch.nn.functional as F
import asyncio
from datetime import datetime
from typing import List, Dict
# Initialize FastAPI app
app = FastAPI(title="LLM Inference Server")
# Determine device (CPU for free tier)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Request queues
chat_queue = asyncio.Queue()
summarize_queue = asyncio.Queue()
queue_stats = {"chat": 0, "summarize": 0}
# Load models
print("Loading Qwen model for chat...")
chat_model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen2.5-0.5B-Instruct",
torch_dtype=torch.float32 if device == "cpu" else torch.float16,
device_map=device
)
chat_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
print("Loading summarization model...")
summarizer = pipeline(
"summarization",
model="Falconsai/text_summarization",
device=device
)
print("Models loaded successfully!")
# Request models
class ChatRequest(BaseModel):
message: str
class SummarizeRequest(BaseModel):
text: str
def generate_with_token_probs(model, tokenizer, messages: List[Dict], max_new_tokens: int = 512):
"""
Generate text and capture top-5 token predictions with probabilities for each step.
"""
# Apply chat template
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)
# Generate with scores
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=0.7,
do_sample=True,
output_scores=True,
return_dict_in_generate=True,
pad_token_id=tokenizer.eos_token_id
)
# Extract generated tokens (excluding input)
generated_tokens = outputs.sequences[0][inputs.input_ids.shape[1]:]
# Process scores to get top-5 predictions for each token
token_data = []
for idx, score in enumerate(outputs.scores):
# Apply softmax to get probabilities
probs = F.softmax(score[0], dim=-1)
# Get top 5 predictions
top5_probs, top5_indices = torch.topk(probs, k=5)
# Decode tokens
top5_tokens = [tokenizer.decode([token_id]) for token_id in top5_indices]
top5_probs_list = [float(prob) * 100 for prob in top5_probs] # Convert to percentage
# Build alternatives list
alternatives = []
for token, prob in zip(top5_tokens, top5_probs_list):
alternatives.append({
"token": token,
"probability": round(prob, 10)
})
# Get the actual generated token
actual_token = tokenizer.decode([generated_tokens[idx]])
token_data.append({
"token": actual_token,
"top5": alternatives
})
# Build full response text
full_response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
return {
"response": full_response,
"tokens": token_data
}
# Queue processor for chat with token probabilities
async def process_chat_queue():
while True:
request_data = await chat_queue.get()
try:
messages = [{"role": "user", "content": request_data["message"]}]
result = generate_with_token_probs(
chat_model,
chat_tokenizer,
messages,
max_new_tokens=512
)
request_data["result"] = result
except Exception as e:
request_data["result"] = {"error": str(e)}
finally:
queue_stats["chat"] = max(0, queue_stats["chat"] - 1)
chat_queue.task_done()
# Queue processor for summarization
async def process_summarize_queue():
while True:
request_data = await summarize_queue.get()
try:
summary = summarizer(
request_data["text"],
max_length=130,
min_length=30,
do_sample=False
)
request_data["result"] = {"summary": summary[0]["summary_text"]}
except Exception as e:
request_data["result"] = {"error": str(e)}
finally:
queue_stats["summarize"] = max(0, queue_stats["summarize"] - 1)
summarize_queue.task_done()
# Start queue processors on startup
@app.on_event("startup")
async def startup_event():
asyncio.create_task(process_chat_queue())
asyncio.create_task(process_summarize_queue())
# Custom HTML UI
@app.get("/", response_class=HTMLResponse)
async def root():
return """
<!DOCTYPE html>
<html>
<head>
<title>LLM Inference Server</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
justify-content: center;
align-items: center;
padding: 20px;
}
.container {
background: white;
border-radius: 20px;
box-shadow: 0 20px 60px rgba(0,0,0,0.3);
max-width: 800px;
width: 100%;
overflow: hidden;
}
.header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 30px;
text-align: center;
}
h1 {
font-size: 2.5em;
margin-bottom: 10px;
}
.subtitle {
opacity: 0.9;
font-size: 1.1em;
}
.tabs {
display: flex;
background: #f5f5f5;
border-bottom: 2px solid #e0e0e0;
}
.tab {
flex: 1;
padding: 20px;
text-align: center;
cursor: pointer;
font-weight: 600;
transition: all 0.3s;
background: #f5f5f5;
border: none;
font-size: 1.1em;
color: black;
}
.tab:hover {
background: #e8e8e8;
}
.tab.active {
background: white;
color: #667eea;
border-bottom: 3px solid #667eea;
}
.content {
padding: 30px;
}
.tab-content {
display: none;
}
.tab-content.active {
display: block;
}
textarea {
width: 100%;
padding: 15px;
border: 2px solid #e0e0e0;
border-radius: 10px;
font-size: 1em;
font-family: inherit;
resize: vertical;
transition: border-color 0.3s;
}
textarea:focus {
outline: none;
border-color: #667eea;
}
button {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
padding: 15px 40px;
font-size: 1.1em;
border-radius: 10px;
cursor: pointer;
margin-top: 20px;
transition: transform 0.2s, box-shadow 0.2s;
font-weight: 600;
}
button:hover {
transform: translateY(-2px);
box-shadow: 0 5px 20px rgba(102, 126, 234, 0.4);
}
button:active {
transform: translateY(0);
}
button:disabled {
opacity: 0.6;
cursor: not-allowed;
transform: none;
}
.response {
margin-top: 20px;
padding: 20px;
background: #f9f9f9;
border-radius: 10px;
border-left: 4px solid #667eea;
white-space: pre-wrap;
word-wrap: break-word;
display: none;
line-height: 1.6;
}
.response.show {
display: block;
}
/* Token visualization styles */
.token {
display: inline;
position: relative;
cursor: help;
padding: 2px 1px;
transition: background-color 0.2s;
}
.token:hover {
background-color: #fff3cd;
border-radius: 3px;
}
.tooltip {
visibility: hidden;
position: absolute;
bottom: 125%;
left: 50%;
transform: translateX(-50%);
background-color: #333;
color: white;
padding: 12px;
border-radius: 8px;
font-size: 0.85em;
white-space: nowrap;
z-index: 1000;
box-shadow: 0 4px 12px rgba(0,0,0,0.3);
min-width: 200px;
}
.tooltip::after {
content: "";
position: absolute;
top: 100%;
left: 50%;
transform: translateX(-50%);
border-width: 6px;
border-style: solid;
border-color: #333 transparent transparent transparent;
}
.token:hover .tooltip {
visibility: visible;
opacity: 1;
}
.tooltip-item {
display: flex;
justify-content: space-between;
padding: 3px 0;
border-bottom: 1px solid rgba(255,255,255,0.1);
}
.tooltip-item:last-child {
border-bottom: none;
}
.tooltip-rank {
color: #ffd700;
font-weight: 600;
margin-right: 8px;
}
.tooltip-token {
font-family: monospace;
color: #fff;
margin-right: 8px;
}
.tooltip-prob {
color: #90ee90;
font-weight: 500;
}
.loading {
display: inline-block;
width: 20px;
height: 20px;
border: 3px solid rgba(255,255,255,.3);
border-radius: 50%;
border-top-color: white;
animation: spin 1s ease-in-out infinite;
}
@keyframes spin {
to { transform: rotate(360deg); }
}
.error {
color: #d32f2f;
background: #ffebee;
border-left-color: #d32f2f;
}
.queue-info {
margin-top: 10px;
padding: 10px;
background: #e3f2fd;
border-radius: 5px;
font-size: 0.9em;
color: #1976d2;
display: none;
}
.queue-info.show {
display: block;
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>🤖 LLM Inference Server</h1>
<p class="subtitle">Powered by Qwen & Falconsai</p>
</div>
<div class="tabs">
<button class="tab active" onclick="switchTab('chat')">💬 Chat</button>
<button class="tab" onclick="switchTab('summarize')">📝 Summarize</button>
</div>
<div class="content">
<div id="chat" class="tab-content active">
<h2 style="margin-bottom: 15px;">Chat with Qwen</h2>
<textarea id="chatInput" rows="4" placeholder="Type your message here..."></textarea>
<button onclick="sendChat()">Send Message</button>
<div id="chatQueue" class="queue-info"></div>
<div id="chatResponse" class="response"></div>
</div>
<div id="summarize" class="tab-content">
<h2 style="margin-bottom: 15px;">Summarize Text</h2>
<textarea id="summarizeInput" rows="8" placeholder="Paste the text you want to summarize..."></textarea>
<button onclick="sendSummarize()">Generate Summary</button>
<div id="summarizeQueue" class="queue-info"></div>
<div id="summarizeResponse" class="response"></div>
</div>
</div>
</div>
<script>
function switchTab(tabName) {
// Hide all tab contents
document.querySelectorAll('.tab-content').forEach(content => {
content.classList.remove('active');
});
// Remove active class from all tabs
document.querySelectorAll('.tab').forEach(tab => {
tab.classList.remove('active');
});
// Show selected tab content
document.getElementById(tabName).classList.add('active');
// Add active class to clicked tab
event.target.classList.add('active');
}
function createTokenTooltip(tokenData, index) {
const tokenSpan = document.createElement('span');
tokenSpan.className = 'token';
tokenSpan.textContent = tokenData.token;
const tooltip = document.createElement('div');
tooltip.className = 'tooltip';
// Create tooltip content
let tooltipHTML = '';
tokenData.top5.forEach((item, rank) => {
const rankLabel = rank === 0 ? '1st (chosen)' :
rank === 1 ? '2nd' :
rank === 2 ? '3rd' :
rank === 3 ? '4th' : '5th';
tooltipHTML += `
<div class="tooltip-item">
<span class="tooltip-rank">${rankLabel}:</span>
<span class="tooltip-token">"${item.token}"</span>
<span class="tooltip-prob">${item.probability.toFixed(10)}%</span>
</div>
`;
});
tooltip.innerHTML = tooltipHTML;
tokenSpan.appendChild(tooltip);
return tokenSpan;
}
async function sendChat() {
const input = document.getElementById('chatInput');
const responseDiv = document.getElementById('chatResponse');
const queueDiv = document.getElementById('chatQueue');
const button = event.target;
if (!input.value.trim()) {
alert('Please enter a message');
return;
}
button.disabled = true;
button.innerHTML = '<span class="loading"></span> Processing...';
responseDiv.classList.remove('show', 'error');
responseDiv.innerHTML = '';
queueDiv.classList.add('show');
queueDiv.textContent = 'Adding to queue...';
try {
const response = await fetch('/chat', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ message: input.value })
});
const data = await response.json();
if (response.ok) {
// Clear previous content
responseDiv.innerHTML = '';
// Check if we have token data
if (data.tokens && data.tokens.length > 0) {
// Create interactive token display
data.tokens.forEach((tokenData, index) => {
const tokenElement = createTokenTooltip(tokenData, index);
responseDiv.appendChild(tokenElement);
});
} else {
// Fallback to plain text
responseDiv.textContent = data.response;
}
responseDiv.classList.remove('error');
queueDiv.classList.remove('show');
} else {
responseDiv.textContent = 'Error: ' + (data.detail || data.error || 'Unknown error');
responseDiv.classList.add('error');
queueDiv.classList.remove('show');
}
responseDiv.classList.add('show');
} catch (error) {
responseDiv.textContent = 'Error: ' + error.message;
responseDiv.classList.add('error', 'show');
queueDiv.classList.remove('show');
} finally {
button.disabled = false;
button.textContent = 'Send Message';
}
}
async function sendSummarize() {
const input = document.getElementById('summarizeInput');
const responseDiv = document.getElementById('summarizeResponse');
const queueDiv = document.getElementById('summarizeQueue');
const button = event.target;
if (!input.value.trim()) {
alert('Please enter text to summarize');
return;
}
button.disabled = true;
button.innerHTML = '<span class="loading"></span> Processing...';
responseDiv.classList.remove('show', 'error');
queueDiv.classList.add('show');
queueDiv.textContent = 'Adding to queue...';
try {
const response = await fetch('/summarize', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ text: input.value })
});
const data = await response.json();
if (response.ok) {
responseDiv.textContent = data.summary;
responseDiv.classList.remove('error');
queueDiv.classList.remove('show');
} else {
responseDiv.textContent = 'Error: ' + (data.detail || data.error || 'Unknown error');
responseDiv.classList.add('error');
queueDiv.classList.remove('show');
}
responseDiv.classList.add('show');
} catch (error) {
responseDiv.textContent = 'Error: ' + error.message;
responseDiv.classList.add('error', 'show');
queueDiv.classList.remove('show');
} finally {
button.disabled = false;
button.textContent = 'Generate Summary';
}
}
// Allow Enter key to submit (with Shift+Enter for new line)
document.getElementById('chatInput').addEventListener('keydown', function(e) {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
sendChat();
}
});
</script>
</body>
</html>
"""
@app.post("/chat")
async def chat(request: ChatRequest):
queue_stats["chat"] += 1
request_data = {"message": request.message, "result": None}
await chat_queue.put(request_data)
# Wait for result with timeout
timeout = 120 # 2 minutes
start_time = datetime.now()
while request_data["result"] is None:
await asyncio.sleep(0.5)
if (datetime.now() - start_time).total_seconds() > timeout:
return JSONResponse(content={"error": "Request timeout"}, status_code=504)
result = request_data["result"]
if "error" in result:
return JSONResponse(content=result, status_code=500)
return JSONResponse(content=result)
@app.post("/summarize")
async def summarize(request: SummarizeRequest):
queue_stats["summarize"] += 1
request_data = {"text": request.text, "result": None}
await summarize_queue.put(request_data)
# Wait for result with timeout
timeout = 120 # 2 minutes
start_time = datetime.now()
while request_data["result"] is None:
await asyncio.sleep(0.5)
if (datetime.now() - start_time).total_seconds() > timeout:
return JSONResponse(content={"error": "Request timeout"}, status_code=504)
result = request_data["result"]
if "error" in result:
return JSONResponse(content=result, status_code=500)
return JSONResponse(content=result)
@app.get("/health")
async def health():
return {"status": "healthy", "device": device, "queue": queue_stats}
@app.get("/queue")
async def get_queue_status():
return {
"chat_queue": queue_stats["chat"],
"summarize_queue": queue_stats["summarize"]
}