# NOTE: Hugging Face Spaces page metadata (scrape residue) — Space status: "Sleeping".
import asyncio
from datetime import datetime
from typing import Dict, List

import torch
import torch.nn.functional as F
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
| # Initialize FastAPI app | |
| app = FastAPI(title="LLM Inference Server") | |
| # Determine device (CPU for free tier) | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| print(f"Using device: {device}") | |
| # Request queues | |
| chat_queue = asyncio.Queue() | |
| summarize_queue = asyncio.Queue() | |
| queue_stats = {"chat": 0, "summarize": 0} | |
| # Load models | |
| print("Loading Qwen model for chat...") | |
| chat_model = AutoModelForCausalLM.from_pretrained( | |
| "Qwen/Qwen2.5-0.5B-Instruct", | |
| torch_dtype=torch.float32 if device == "cpu" else torch.float16, | |
| device_map=device | |
| ) | |
| chat_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") | |
| print("Loading summarization model...") | |
| summarizer = pipeline( | |
| "summarization", | |
| model="Falconsai/text_summarization", | |
| device=device | |
| ) | |
| print("Models loaded successfully!") | |
| # Request models | |
| class ChatRequest(BaseModel): | |
| message: str | |
| class SummarizeRequest(BaseModel): | |
| text: str | |
| def generate_with_token_probs(model, tokenizer, messages: List[Dict], max_new_tokens: int = 512): | |
| """ | |
| Generate text and capture top-5 token predictions with probabilities for each step. | |
| """ | |
| # Apply chat template | |
| text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) | |
| inputs = tokenizer(text, return_tensors="pt").to(model.device) | |
| # Generate with scores | |
| outputs = model.generate( | |
| **inputs, | |
| max_new_tokens=max_new_tokens, | |
| temperature=0.7, | |
| do_sample=True, | |
| output_scores=True, | |
| return_dict_in_generate=True, | |
| pad_token_id=tokenizer.eos_token_id | |
| ) | |
| # Extract generated tokens (excluding input) | |
| generated_tokens = outputs.sequences[0][inputs.input_ids.shape[1]:] | |
| # Process scores to get top-5 predictions for each token | |
| token_data = [] | |
| for idx, score in enumerate(outputs.scores): | |
| # Apply softmax to get probabilities | |
| probs = F.softmax(score[0], dim=-1) | |
| # Get top 5 predictions | |
| top5_probs, top5_indices = torch.topk(probs, k=5) | |
| # Decode tokens | |
| top5_tokens = [tokenizer.decode([token_id]) for token_id in top5_indices] | |
| top5_probs_list = [float(prob) * 100 for prob in top5_probs] # Convert to percentage | |
| # Build alternatives list | |
| alternatives = [] | |
| for token, prob in zip(top5_tokens, top5_probs_list): | |
| alternatives.append({ | |
| "token": token, | |
| "probability": round(prob, 10) | |
| }) | |
| # Get the actual generated token | |
| actual_token = tokenizer.decode([generated_tokens[idx]]) | |
| token_data.append({ | |
| "token": actual_token, | |
| "top5": alternatives | |
| }) | |
| # Build full response text | |
| full_response = tokenizer.decode(generated_tokens, skip_special_tokens=True) | |
| return { | |
| "response": full_response, | |
| "tokens": token_data | |
| } | |
| # Queue processor for chat with token probabilities | |
| async def process_chat_queue(): | |
| while True: | |
| request_data = await chat_queue.get() | |
| try: | |
| messages = [{"role": "user", "content": request_data["message"]}] | |
| result = generate_with_token_probs( | |
| chat_model, | |
| chat_tokenizer, | |
| messages, | |
| max_new_tokens=512 | |
| ) | |
| request_data["result"] = result | |
| except Exception as e: | |
| request_data["result"] = {"error": str(e)} | |
| finally: | |
| queue_stats["chat"] = max(0, queue_stats["chat"] - 1) | |
| chat_queue.task_done() | |
| # Queue processor for summarization | |
| async def process_summarize_queue(): | |
| while True: | |
| request_data = await summarize_queue.get() | |
| try: | |
| summary = summarizer( | |
| request_data["text"], | |
| max_length=130, | |
| min_length=30, | |
| do_sample=False | |
| ) | |
| request_data["result"] = {"summary": summary[0]["summary_text"]} | |
| except Exception as e: | |
| request_data["result"] = {"error": str(e)} | |
| finally: | |
| queue_stats["summarize"] = max(0, queue_stats["summarize"] - 1) | |
| summarize_queue.task_done() | |
| # Start queue processors on startup | |
| async def startup_event(): | |
| asyncio.create_task(process_chat_queue()) | |
| asyncio.create_task(process_summarize_queue()) | |
| # Custom HTML UI | |
| async def root(): | |
| return """ | |
| <!DOCTYPE html> | |
| <html> | |
| <head> | |
| <title>LLM Inference Server</title> | |
| <style> | |
| * { | |
| margin: 0; | |
| padding: 0; | |
| box-sizing: border-box; | |
| } | |
| body { | |
| font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); | |
| min-height: 100vh; | |
| display: flex; | |
| justify-content: center; | |
| align-items: center; | |
| padding: 20px; | |
| } | |
| .container { | |
| background: white; | |
| border-radius: 20px; | |
| box-shadow: 0 20px 60px rgba(0,0,0,0.3); | |
| max-width: 800px; | |
| width: 100%; | |
| overflow: hidden; | |
| } | |
| .header { | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); | |
| color: white; | |
| padding: 30px; | |
| text-align: center; | |
| } | |
| h1 { | |
| font-size: 2.5em; | |
| margin-bottom: 10px; | |
| } | |
| .subtitle { | |
| opacity: 0.9; | |
| font-size: 1.1em; | |
| } | |
| .tabs { | |
| display: flex; | |
| background: #f5f5f5; | |
| border-bottom: 2px solid #e0e0e0; | |
| } | |
| .tab { | |
| flex: 1; | |
| padding: 20px; | |
| text-align: center; | |
| cursor: pointer; | |
| font-weight: 600; | |
| transition: all 0.3s; | |
| background: #f5f5f5; | |
| border: none; | |
| font-size: 1.1em; | |
| color: black; | |
| } | |
| .tab:hover { | |
| background: #e8e8e8; | |
| } | |
| .tab.active { | |
| background: white; | |
| color: #667eea; | |
| border-bottom: 3px solid #667eea; | |
| } | |
| .content { | |
| padding: 30px; | |
| } | |
| .tab-content { | |
| display: none; | |
| } | |
| .tab-content.active { | |
| display: block; | |
| } | |
| textarea { | |
| width: 100%; | |
| padding: 15px; | |
| border: 2px solid #e0e0e0; | |
| border-radius: 10px; | |
| font-size: 1em; | |
| font-family: inherit; | |
| resize: vertical; | |
| transition: border-color 0.3s; | |
| } | |
| textarea:focus { | |
| outline: none; | |
| border-color: #667eea; | |
| } | |
| button { | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); | |
| color: white; | |
| border: none; | |
| padding: 15px 40px; | |
| font-size: 1.1em; | |
| border-radius: 10px; | |
| cursor: pointer; | |
| margin-top: 20px; | |
| transition: transform 0.2s, box-shadow 0.2s; | |
| font-weight: 600; | |
| } | |
| button:hover { | |
| transform: translateY(-2px); | |
| box-shadow: 0 5px 20px rgba(102, 126, 234, 0.4); | |
| } | |
| button:active { | |
| transform: translateY(0); | |
| } | |
| button:disabled { | |
| opacity: 0.6; | |
| cursor: not-allowed; | |
| transform: none; | |
| } | |
| .response { | |
| margin-top: 20px; | |
| padding: 20px; | |
| background: #f9f9f9; | |
| border-radius: 10px; | |
| border-left: 4px solid #667eea; | |
| white-space: pre-wrap; | |
| word-wrap: break-word; | |
| display: none; | |
| line-height: 1.6; | |
| } | |
| .response.show { | |
| display: block; | |
| } | |
| /* Token visualization styles */ | |
| .token { | |
| display: inline; | |
| position: relative; | |
| cursor: help; | |
| padding: 2px 1px; | |
| transition: background-color 0.2s; | |
| } | |
| .token:hover { | |
| background-color: #fff3cd; | |
| border-radius: 3px; | |
| } | |
| .tooltip { | |
| visibility: hidden; | |
| position: absolute; | |
| bottom: 125%; | |
| left: 50%; | |
| transform: translateX(-50%); | |
| background-color: #333; | |
| color: white; | |
| padding: 12px; | |
| border-radius: 8px; | |
| font-size: 0.85em; | |
| white-space: nowrap; | |
| z-index: 1000; | |
| box-shadow: 0 4px 12px rgba(0,0,0,0.3); | |
| min-width: 200px; | |
| } | |
| .tooltip::after { | |
| content: ""; | |
| position: absolute; | |
| top: 100%; | |
| left: 50%; | |
| transform: translateX(-50%); | |
| border-width: 6px; | |
| border-style: solid; | |
| border-color: #333 transparent transparent transparent; | |
| } | |
| .token:hover .tooltip { | |
| visibility: visible; | |
| opacity: 1; | |
| } | |
| .tooltip-item { | |
| display: flex; | |
| justify-content: space-between; | |
| padding: 3px 0; | |
| border-bottom: 1px solid rgba(255,255,255,0.1); | |
| } | |
| .tooltip-item:last-child { | |
| border-bottom: none; | |
| } | |
| .tooltip-rank { | |
| color: #ffd700; | |
| font-weight: 600; | |
| margin-right: 8px; | |
| } | |
| .tooltip-token { | |
| font-family: monospace; | |
| color: #fff; | |
| margin-right: 8px; | |
| } | |
| .tooltip-prob { | |
| color: #90ee90; | |
| font-weight: 500; | |
| } | |
| .loading { | |
| display: inline-block; | |
| width: 20px; | |
| height: 20px; | |
| border: 3px solid rgba(255,255,255,.3); | |
| border-radius: 50%; | |
| border-top-color: white; | |
| animation: spin 1s ease-in-out infinite; | |
| } | |
| @keyframes spin { | |
| to { transform: rotate(360deg); } | |
| } | |
| .error { | |
| color: #d32f2f; | |
| background: #ffebee; | |
| border-left-color: #d32f2f; | |
| } | |
| .queue-info { | |
| margin-top: 10px; | |
| padding: 10px; | |
| background: #e3f2fd; | |
| border-radius: 5px; | |
| font-size: 0.9em; | |
| color: #1976d2; | |
| display: none; | |
| } | |
| .queue-info.show { | |
| display: block; | |
| } | |
| </style> | |
| </head> | |
| <body> | |
| <div class="container"> | |
| <div class="header"> | |
| <h1>🤖 LLM Inference Server</h1> | |
| <p class="subtitle">Powered by Qwen & Falconsai</p> | |
| </div> | |
| <div class="tabs"> | |
| <button class="tab active" onclick="switchTab('chat')">💬 Chat</button> | |
| <button class="tab" onclick="switchTab('summarize')">📝 Summarize</button> | |
| </div> | |
| <div class="content"> | |
| <div id="chat" class="tab-content active"> | |
| <h2 style="margin-bottom: 15px;">Chat with Qwen</h2> | |
| <textarea id="chatInput" rows="4" placeholder="Type your message here..."></textarea> | |
| <button onclick="sendChat()">Send Message</button> | |
| <div id="chatQueue" class="queue-info"></div> | |
| <div id="chatResponse" class="response"></div> | |
| </div> | |
| <div id="summarize" class="tab-content"> | |
| <h2 style="margin-bottom: 15px;">Summarize Text</h2> | |
| <textarea id="summarizeInput" rows="8" placeholder="Paste the text you want to summarize..."></textarea> | |
| <button onclick="sendSummarize()">Generate Summary</button> | |
| <div id="summarizeQueue" class="queue-info"></div> | |
| <div id="summarizeResponse" class="response"></div> | |
| </div> | |
| </div> | |
| </div> | |
| <script> | |
| function switchTab(tabName) { | |
| // Hide all tab contents | |
| document.querySelectorAll('.tab-content').forEach(content => { | |
| content.classList.remove('active'); | |
| }); | |
| // Remove active class from all tabs | |
| document.querySelectorAll('.tab').forEach(tab => { | |
| tab.classList.remove('active'); | |
| }); | |
| // Show selected tab content | |
| document.getElementById(tabName).classList.add('active'); | |
| // Add active class to clicked tab | |
| event.target.classList.add('active'); | |
| } | |
| function createTokenTooltip(tokenData, index) { | |
| const tokenSpan = document.createElement('span'); | |
| tokenSpan.className = 'token'; | |
| tokenSpan.textContent = tokenData.token; | |
| const tooltip = document.createElement('div'); | |
| tooltip.className = 'tooltip'; | |
| // Create tooltip content | |
| let tooltipHTML = ''; | |
| tokenData.top5.forEach((item, rank) => { | |
| const rankLabel = rank === 0 ? '1st (chosen)' : | |
| rank === 1 ? '2nd' : | |
| rank === 2 ? '3rd' : | |
| rank === 3 ? '4th' : '5th'; | |
| tooltipHTML += ` | |
| <div class="tooltip-item"> | |
| <span class="tooltip-rank">${rankLabel}:</span> | |
| <span class="tooltip-token">"${item.token}"</span> | |
| <span class="tooltip-prob">${item.probability.toFixed(10)}%</span> | |
| </div> | |
| `; | |
| }); | |
| tooltip.innerHTML = tooltipHTML; | |
| tokenSpan.appendChild(tooltip); | |
| return tokenSpan; | |
| } | |
| async function sendChat() { | |
| const input = document.getElementById('chatInput'); | |
| const responseDiv = document.getElementById('chatResponse'); | |
| const queueDiv = document.getElementById('chatQueue'); | |
| const button = event.target; | |
| if (!input.value.trim()) { | |
| alert('Please enter a message'); | |
| return; | |
| } | |
| button.disabled = true; | |
| button.innerHTML = '<span class="loading"></span> Processing...'; | |
| responseDiv.classList.remove('show', 'error'); | |
| responseDiv.innerHTML = ''; | |
| queueDiv.classList.add('show'); | |
| queueDiv.textContent = 'Adding to queue...'; | |
| try { | |
| const response = await fetch('/chat', { | |
| method: 'POST', | |
| headers: { | |
| 'Content-Type': 'application/json', | |
| }, | |
| body: JSON.stringify({ message: input.value }) | |
| }); | |
| const data = await response.json(); | |
| if (response.ok) { | |
| // Clear previous content | |
| responseDiv.innerHTML = ''; | |
| // Check if we have token data | |
| if (data.tokens && data.tokens.length > 0) { | |
| // Create interactive token display | |
| data.tokens.forEach((tokenData, index) => { | |
| const tokenElement = createTokenTooltip(tokenData, index); | |
| responseDiv.appendChild(tokenElement); | |
| }); | |
| } else { | |
| // Fallback to plain text | |
| responseDiv.textContent = data.response; | |
| } | |
| responseDiv.classList.remove('error'); | |
| queueDiv.classList.remove('show'); | |
| } else { | |
| responseDiv.textContent = 'Error: ' + (data.detail || data.error || 'Unknown error'); | |
| responseDiv.classList.add('error'); | |
| queueDiv.classList.remove('show'); | |
| } | |
| responseDiv.classList.add('show'); | |
| } catch (error) { | |
| responseDiv.textContent = 'Error: ' + error.message; | |
| responseDiv.classList.add('error', 'show'); | |
| queueDiv.classList.remove('show'); | |
| } finally { | |
| button.disabled = false; | |
| button.textContent = 'Send Message'; | |
| } | |
| } | |
| async function sendSummarize() { | |
| const input = document.getElementById('summarizeInput'); | |
| const responseDiv = document.getElementById('summarizeResponse'); | |
| const queueDiv = document.getElementById('summarizeQueue'); | |
| const button = event.target; | |
| if (!input.value.trim()) { | |
| alert('Please enter text to summarize'); | |
| return; | |
| } | |
| button.disabled = true; | |
| button.innerHTML = '<span class="loading"></span> Processing...'; | |
| responseDiv.classList.remove('show', 'error'); | |
| queueDiv.classList.add('show'); | |
| queueDiv.textContent = 'Adding to queue...'; | |
| try { | |
| const response = await fetch('/summarize', { | |
| method: 'POST', | |
| headers: { | |
| 'Content-Type': 'application/json', | |
| }, | |
| body: JSON.stringify({ text: input.value }) | |
| }); | |
| const data = await response.json(); | |
| if (response.ok) { | |
| responseDiv.textContent = data.summary; | |
| responseDiv.classList.remove('error'); | |
| queueDiv.classList.remove('show'); | |
| } else { | |
| responseDiv.textContent = 'Error: ' + (data.detail || data.error || 'Unknown error'); | |
| responseDiv.classList.add('error'); | |
| queueDiv.classList.remove('show'); | |
| } | |
| responseDiv.classList.add('show'); | |
| } catch (error) { | |
| responseDiv.textContent = 'Error: ' + error.message; | |
| responseDiv.classList.add('error', 'show'); | |
| queueDiv.classList.remove('show'); | |
| } finally { | |
| button.disabled = false; | |
| button.textContent = 'Generate Summary'; | |
| } | |
| } | |
| // Allow Enter key to submit (with Shift+Enter for new line) | |
| document.getElementById('chatInput').addEventListener('keydown', function(e) { | |
| if (e.key === 'Enter' && !e.shiftKey) { | |
| e.preventDefault(); | |
| sendChat(); | |
| } | |
| }); | |
| </script> | |
| </body> | |
| </html> | |
| """ | |
| async def chat(request: ChatRequest): | |
| queue_stats["chat"] += 1 | |
| request_data = {"message": request.message, "result": None} | |
| await chat_queue.put(request_data) | |
| # Wait for result with timeout | |
| timeout = 120 # 2 minutes | |
| start_time = datetime.now() | |
| while request_data["result"] is None: | |
| await asyncio.sleep(0.5) | |
| if (datetime.now() - start_time).total_seconds() > timeout: | |
| return JSONResponse(content={"error": "Request timeout"}, status_code=504) | |
| result = request_data["result"] | |
| if "error" in result: | |
| return JSONResponse(content=result, status_code=500) | |
| return JSONResponse(content=result) | |
| async def summarize(request: SummarizeRequest): | |
| queue_stats["summarize"] += 1 | |
| request_data = {"text": request.text, "result": None} | |
| await summarize_queue.put(request_data) | |
| # Wait for result with timeout | |
| timeout = 120 # 2 minutes | |
| start_time = datetime.now() | |
| while request_data["result"] is None: | |
| await asyncio.sleep(0.5) | |
| if (datetime.now() - start_time).total_seconds() > timeout: | |
| return JSONResponse(content={"error": "Request timeout"}, status_code=504) | |
| result = request_data["result"] | |
| if "error" in result: | |
| return JSONResponse(content=result, status_code=500) | |
| return JSONResponse(content=result) | |
| async def health(): | |
| return {"status": "healthy", "device": device, "queue": queue_stats} | |
| async def get_queue_status(): | |
| return { | |
| "chat_queue": queue_stats["chat"], | |
| "summarize_queue": queue_stats["summarize"] | |
| } | |