worker-universal / worker_app.py
Bc-AI's picture
Update worker_app.py
657c142 verified
import os
import time
import json
import asyncio
from datetime import datetime
from typing import Dict, List, Optional
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
import uvicorn
from pydantic import BaseModel
from shared.models import ChatRequest, ChatResponse, ChatMessage
import tensorflow as tf
import keras
import numpy as np
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import requests
from transformers import GPT2Tokenizer
from shared.model_manager import ModelManager
# FastAPI application that exposes this worker's HTTP endpoints.
app = FastAPI(
    title="Universal Worker Node for Sam-X Models",
    description="Processing node that supports all Sam-X model types dynamically",
    version="2.0.0"
)
# Global model manager instance shared by every request handler; handlers
# obtain (model, tokenizer, config) from it via model_manager.get_model(...).
model_manager = ModelManager()
# Health flag reported by /health. Always true since we're using lazy loading;
# startup_event only flips it to False if its initialization step raises.
model_loaded = True # Always true since we're using lazy loading
# Performance optimizations
# Fall back to 4 threads when os.cpu_count() cannot determine the core count.
NUM_CORES = os.cpu_count() or 4
# NOTE(review): these environment variables are assigned after `import tensorflow`
# at the top of the file. TF_ENABLE_ONEDNN_OPTS and TF_CPP_MIN_LOG_LEVEL are
# generally read when TensorFlow is loaded, so they likely have no effect here
# (CUDA_VISIBLE_DEVICES may still work if no device has been initialized yet).
# Confirm, and move them above the TensorFlow import if they are meant to apply.
os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES)
os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # Force CPU only
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1' # Intel optimization
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Reduce TF logging
# Configure TF threading through the tf.config API, which does not depend on
# import order. Per the TF docs these setters must run before the TF runtime is
# initialized; if ModelManager() above ever loads a model eagerly, they may raise.
tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
# NOTE(review): printed unconditionally; it does not verify that oneDNN is active.
print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN enabled")
def format_chat_prompt(messages: List[Dict[str, str]]) -> str:
    """Flatten a list of chat messages into a single prompt string.

    Each message is a dict with an optional ``'role'`` key (default
    ``'user'``) and an optional ``'content'`` key (default ``''``). Roles are
    matched case-insensitively:

    * ``user`` and ``assistant`` messages are rendered as newline, content,
      newline;
    * any other role (e.g. ``system``) is rendered as content, newline.

    One extra newline is appended at the end as the slot for the model's reply.

    NOTE(review): user and assistant turns are rendered identically, so the
    model cannot tell the speakers apart. Role/template markers appear to be
    missing from these literals; confirm against the model's chat template.
    """
    segments = []
    for entry in messages:
        speaker = entry.get('role', 'user').lower()
        text = entry.get('content', '')
        is_dialogue_turn = speaker in ('user', 'assistant')
        segments.append(f"\n{text}\n" if is_dialogue_turn else f"{text}\n")
    # Trailing newline opens the assistant's response slot.
    segments.append("\n")
    return "".join(segments)
def sample_token(logits, temperature=0.8, top_k=40, top_p=0.9, repetition_penalty=1.1,
                 previous_token_ids=None):
    """Sample the next token id from a 1-D vector of logits.

    Pipeline: repetition penalty -> temperature -> softmax -> top-k -> top-p
    -> multinomial draw. Top-k and top-p compose: the nucleus filter runs on
    whatever survives the top-k filter.

    Args:
        logits: 1-D array-like of unnormalised scores, one per vocabulary entry.
        temperature: Softmax temperature. A value ``<= 0`` selects greedy
            (argmax) decoding instead of dividing by zero.
        top_k: Keep only the ``top_k`` most likely tokens. ``<= 0`` or a value
            ``>=`` the vocabulary size disables the filter.
        top_p: Keep the smallest set of most-likely tokens whose cumulative
            probability reaches ``top_p`` (at least one token is always kept).
            ``>= 1.0`` disables the filter.
        repetition_penalty: CTRL-style penalty applied only to ids listed in
            ``previous_token_ids``. Positive logits are divided by it and
            negative logits multiplied, so values > 1 discourage repeats.
        previous_token_ids: Optional iterable of token ids already seen. When
            ``None`` there is nothing to penalise and the penalty is skipped.
            Out-of-range ids are ignored.

    Returns:
        The sampled token id as a plain Python ``int``.
    """
    # Copy as float64: the penalty below writes into this array in place.
    scores = np.array(logits, dtype=np.float64)
    vocab_size = scores.shape[-1]

    if repetition_penalty != 1.0 and previous_token_ids is not None:
        seen = np.unique(np.asarray(list(previous_token_ids), dtype=np.int64))
        seen = seen[(seen >= 0) & (seen < vocab_size)]
        if seen.size:
            picked = scores[seen]
            # Move seen tokens toward "less likely" regardless of their sign.
            scores[seen] = np.where(picked < 0, picked * repetition_penalty,
                                    picked / repetition_penalty)

    if temperature <= 0:
        return int(np.argmax(scores))

    scores = scores / temperature
    probs = np.exp(scores - np.max(scores))  # subtract max for numerical stability
    probs /= probs.sum()

    candidates = np.arange(vocab_size)
    if 0 < top_k < vocab_size:
        candidates = np.argpartition(probs, -top_k)[-top_k:]
    cand_probs = probs[candidates]

    if top_p < 1.0:
        order = np.argsort(cand_probs)[::-1]
        candidates = candidates[order]
        cand_probs = cand_probs[order]
        cumulative = np.cumsum(cand_probs) / cand_probs.sum()
        # +1 keeps the token that crosses the threshold, so at least one survives.
        keep = min(int(np.searchsorted(cumulative, top_p)) + 1, candidates.size)
        candidates = candidates[:keep]
        cand_probs = cand_probs[:keep]

    cand_probs = cand_probs / cand_probs.sum()
    return int(candidates[np.random.choice(candidates.size, p=cand_probs)])
def generate_response(model: keras.Model, tokenizer: Tokenizer, config: dict,
                      prompt: str, max_tokens: int = 512, temperature: float = 0.8,
                      top_k: int = 40, top_p: float = 0.9, repetition_penalty: float = 1.1) -> str:
    """Generate a complete (non-streaming) completion for ``prompt``.

    The loop runs without a KV cache, so every step re-runs the model on the
    whole context: the prompt plus the tokens generated so far, windowed to
    the last ``max_position_embeddings`` positions.

    Args:
        model: Called as ``model(ids, training=False, use_cache=False)`` and
            expected to return ``(logits, _)``. Logits are indexed as
            ``logits[0, -1, :]`` (presumably batch, position, vocab).
        tokenizer: ``tokenizers.Tokenizer`` used to encode the prompt and
            decode the generated ids.
        config: Model config dict. Reads ``eos_token_id`` (default 50256) and
            ``max_position_embeddings`` (default 2048, used as the context window).
        prompt: Already-formatted prompt text.
        max_tokens: Upper bound on the number of generated tokens.
        temperature, top_k, top_p, repetition_penalty: Forwarded to
            ``sample_token``.

    Returns:
        The generated text, cut at the first newline or
        ``"<im end for model tun>"`` and stripped of surrounding whitespace.

    Raises:
        ValueError: If ``prompt`` encodes to zero tokens, which leaves the
            model no input position.
    """
    context_ids = list(tokenizer.encode(prompt).ids)
    if not context_ids:
        raise ValueError("Prompt encodes to zero tokens; cannot generate")

    max_context = int(config.get('max_position_embeddings', 2048) or 2048)
    # token_to_id returns None for tokens missing from the vocab; drop those.
    stop_token_ids = [
        tid for tid in (
            config.get('eos_token_id', 50256),
            tokenizer.token_to_id("\n"),
            tokenizer.token_to_id("<im end for model tun>"),
        )
        if tid is not None
    ]

    generated_ids = []
    for _step in range(max_tokens):
        # No KV cache is used, so the whole (windowed) context must be re-fed each
        # step. Feeding only the last token would discard the prompt.
        window = tf.constant([context_ids[-max_context:]], dtype=tf.int32)
        with tf.device('/CPU:0'):  # Use CPU for inference
            logits, _ = model(window, training=False, use_cache=False)
        next_token_logits = logits[0, -1, :].numpy()
        next_token_id = int(sample_token(next_token_logits, temperature, top_k, top_p, repetition_penalty))
        # Stop tokens end generation and are not part of the returned text.
        if next_token_id in stop_token_ids:
            break
        generated_ids.append(next_token_id)
        context_ids.append(next_token_id)

    generated_text = tokenizer.decode(generated_ids)
    # A stop marker can still appear inside a multi-character token; cut there.
    for stop_text in ("\n", "<im end for model tun>"):
        idx = generated_text.find(stop_text)
        if idx != -1:
            generated_text = generated_text[:idx]
    return generated_text.strip()
async def generate_streaming_response(model: keras.Model, tokenizer: Tokenizer, config: dict,
                                      prompt: str, max_tokens: int = 512, temperature: float = 0.8,
                                      top_k: int = 40, top_p: float = 0.9, repetition_penalty: float = 1.1):
    """Yield an OpenAI-style server-sent-event stream for ``prompt``.

    Each yielded item has the form ``"data: <json>\\n\\n"``, where the JSON is a
    ``chat.completion.chunk``. The stream is:

    1. a role-only delta;
    2. one content chunk per generated token;
    3. an empty delta carrying ``finish_reason``: ``"stop"`` when a stop token
       was sampled, ``"length"`` when ``max_tokens`` ran out.

    All chunks share one ``id`` and ``created``. ``"model"`` is the placeholder
    ``"dynamic_model"``, which the caller is expected to replace.

    Stop tokens are ``eos_token_id`` from config (default 50256) plus the ids of
    ``"\\n"`` and ``"<im end for model tun>"`` when they exist in the vocab.
    They end the stream and are never emitted as content. The context is
    re-fed in full each step (no KV cache), windowed to
    ``max_position_embeddings`` (default 2048).

    Raises:
        ValueError: If ``prompt`` encodes to zero tokens.
    """
    context_ids = list(tokenizer.encode(prompt).ids)
    if not context_ids:
        raise ValueError("Prompt encodes to zero tokens; cannot generate")

    max_context = int(config.get('max_position_embeddings', 2048) or 2048)
    # token_to_id returns None for tokens missing from the vocab; drop those.
    stop_token_ids = [
        tid for tid in (
            config.get('eos_token_id', 50256),
            tokenizer.token_to_id("\n"),
            tokenizer.token_to_id("<im end for model tun>"),
        )
        if tid is not None
    ]

    # One id/timestamp for the whole completion, as OpenAI clients expect.
    completion_id = f"chat-{int(time.time())}"
    created = int(time.time())

    def make_chunk(delta: dict, finish_reason) -> str:
        """Serialize one chat.completion.chunk as an SSE data line."""
        payload = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": "dynamic_model",  # Will be set by the calling function
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
            }],
        }
        return f"data: {json.dumps(payload)}\n\n"

    # Send initial chunk with role
    yield make_chunk({"role": "assistant", "content": ""}, None)

    finish_reason = "length"
    for _step in range(max_tokens):
        # NOTE(review): this forward pass is synchronous, so each step blocks the
        # event loop for its duration.
        window = tf.constant([context_ids[-max_context:]], dtype=tf.int32)
        with tf.device('/CPU:0'):  # Use CPU for inference
            logits, _ = model(window, training=False, use_cache=False)
        next_token_logits = logits[0, -1, :].numpy()
        next_token_id = int(sample_token(next_token_logits, temperature, top_k, top_p, repetition_penalty))
        # Check before emitting so the stop token never reaches the client.
        if next_token_id in stop_token_ids:
            finish_reason = "stop"
            break
        context_ids.append(next_token_id)
        # NOTE: decoding one token at a time may split multi-byte characters for
        # byte-level BPE vocabularies.
        yield make_chunk({"content": tokenizer.decode([next_token_id])}, None)

    # Final chunk
    yield make_chunk({}, finish_reason)
async def generate_token_by_token_streaming_response(model: keras.Model, tokenizer: Tokenizer, config: dict,
                                                     prompt: str, max_tokens: int = 512, temperature: float = 0.8,
                                                     top_k: int = 40, top_p: float = 0.9, repetition_penalty: float = 1.1):
    """Generate streaming response with token-by-token processing, suitable for SACCP distribution.

    Yields ``"data: <json>\\n\\n"`` SSE lines carrying ``chat.completion.chunk``
    payloads:

    * a role-only delta (id ``chat-<unix time>``);
    * one content chunk per generated token (id ``token-<step>-<unix time>``);
    * a final empty delta whose ``finish_reason`` is ``"stop"`` when a stop
      token was sampled, or ``"length"`` when ``max_tokens`` ran out.

    Stop tokens are ``eos_token_id`` from config (default 50256) plus the ids of
    ``"\\n"`` and ``"<im end for model tun>"`` when they exist in the vocab;
    they are never emitted as content. Each step re-feeds the whole context
    (no KV cache), windowed to ``max_position_embeddings`` (default 2048).

    Raises:
        ValueError: If ``prompt`` encodes to zero tokens.
    """
    context_ids = list(tokenizer.encode(prompt).ids)
    if not context_ids:
        raise ValueError("Prompt encodes to zero tokens; cannot generate")

    max_context = int(config.get('max_position_embeddings', 2048) or 2048)
    # token_to_id returns None for tokens missing from the vocab; drop those.
    stop_token_ids = [
        tid for tid in (
            config.get('eos_token_id', 50256),
            tokenizer.token_to_id("\n"),
            tokenizer.token_to_id("<im end for model tun>"),
        )
        if tid is not None
    ]

    def make_chunk(chunk_id: str, delta: dict, finish_reason) -> str:
        """Serialize one chat.completion.chunk as an SSE data line."""
        payload = {
            "id": chunk_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": "dynamic_model",
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
            }],
        }
        return f"data: {json.dumps(payload)}\n\n"

    # Send initial chunk with role
    yield make_chunk(f"chat-{int(time.time())}", {"role": "assistant", "content": ""}, None)

    finish_reason = "length"
    for i in range(max_tokens):
        # Process one token at a time (in a real SACCP scenario, this could be distributed).
        window = tf.constant([context_ids[-max_context:]], dtype=tf.int32)
        with tf.device('/CPU:0'):
            logits, _ = model(window, training=False, use_cache=False)
        next_token_logits = logits[0, -1, :].numpy()
        next_token_id = int(sample_token(next_token_logits, temperature, top_k, top_p, repetition_penalty))
        # Check before emitting so the stop token never reaches the client.
        if next_token_id in stop_token_ids:
            finish_reason = "stop"
            break
        context_ids.append(next_token_id)
        yield make_chunk(f"token-{i}-{int(time.time())}",
                         {"content": tokenizer.decode([next_token_id])}, None)

    # Final chunk
    yield make_chunk(f"chat-{int(time.time())}", {}, finish_reason)
@app.on_event("startup")
def startup_event():
    """Initialize model manager on startup.

    Logs the models the manager reports. If that query raises, the error is
    logged and the module-level ``model_loaded`` flag (reported by /health) is
    set to False. Previously the query ran outside the ``try``, so its failure
    would propagate out of the startup hook and the except branch could never
    handle it.
    """
    global model_loaded
    print("Initializing universal worker...")
    try:
        print(f"Available models: {model_manager.list_available_models()}")
        print("✅ Universal worker initialized successfully!")
        print("This worker can dynamically load any Sam-X model based on requests")
    except Exception as e:
        print(f"❌ Worker initialization failed: {e}")
        model_loaded = False
@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
    """Process an OpenAI-compatible chat completion request.

    The requested model name is lower-cased and resolved against
    ``model_manager.list_available_models()``. When there is no exact match,
    the first available name that contains (or is contained in) the request
    is used. Messages are flattened with ``format_chat_prompt``.

    * ``request.stream``: returns a ``text/event-stream`` ``StreamingResponse``
      whose chunks have ``"model"`` rewritten to ``request.model``.
    * otherwise: runs generation in the default executor, so the event loop
      stays free, and returns the response dict with usage counted in tokens.

    Raises:
        HTTPException: 400 if no available model matches ``request.model``;
            500 for any other failure before streaming starts or during a
            non-streaming generation.
    """
    try:
        # Extract model type from request
        model_type = request.model.lower()
        available_models = model_manager.list_available_models()
        if model_type not in available_models:
            # Fall back to the first model whose name overlaps the requested one.
            matching_models = [m for m in available_models if model_type in m or m in model_type]
            if not matching_models:
                raise HTTPException(
                    status_code=400,
                    detail=f"Model {request.model} not available. Available models: {available_models}"
                )
            model_type = matching_models[0]

        # Get the appropriate model and tokenizer for this request
        model, tokenizer, config = model_manager.get_model(model_type)

        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
        prompt = format_chat_prompt(messages)

        generation_kwargs = dict(
            model=model,
            tokenizer=tokenizer,
            config=config,
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            repetition_penalty=request.repetition_penalty,
        )

        if request.stream:
            async def generate():
                async for chunk in generate_streaming_response(**generation_kwargs):
                    # Chunks look like "data: {json}\n\n". Strip the prefix and the
                    # trailing blank line, then rewrite the placeholder model name.
                    chunk_data = json.loads(chunk[len("data: "):].strip())
                    chunk_data["model"] = request.model
                    yield f"data: {json.dumps(chunk_data)}\n\n"
            return StreamingResponse(generate(), media_type="text/event-stream")

        start_time = time.time()
        # Generation is synchronous and CPU-bound. Run it off the event loop so
        # other requests (e.g. /health) are not stalled for its whole duration.
        loop = asyncio.get_running_loop()
        response_text = await loop.run_in_executor(
            None, lambda: generate_response(**generation_kwargs)
        )
        processing_time = time.time() - start_time

        # Usage is reported in tokens, as the OpenAI format specifies.
        prompt_tokens = len(tokenizer.encode(prompt).ids)
        completion_tokens = len(tokenizer.encode(response_text).ids)

        response = ChatResponse(
            id=f"chat-{int(time.time())}",
            model=request.model,  # Use original model name
            choices=[
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": response_text},
                    "finish_reason": "stop"
                }
            ],
            usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
        )
        print(f"Generated response in {processing_time:.2f}s for model {request.model} (loaded as {model_type})")
        return response.dict()
    except HTTPException:
        # Preserve intentional client errors (e.g. 400) instead of masking them as 500.
        raise
    except Exception as e:
        print(f"Error processing request: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
@app.get("/health")
async def health_check():
    """Report worker liveness plus the supported and currently loaded models.

    ``status`` and ``model_loaded`` both reflect the module-level
    ``model_loaded`` flag. ``timestamp`` is the current Unix time in whole
    seconds. ``supported_models`` comes from the model manager, and
    ``loaded_models`` lists the keys of ``model_manager.models``.
    """
    ready = model_loaded
    status_label = "unhealthy"
    if ready:
        status_label = "healthy"
    now = int(time.time())
    supported = model_manager.list_available_models()
    resident = [*model_manager.models.keys()]
    return {
        "status": status_label,
        "model_loaded": ready,
        "timestamp": now,
        "supported_models": supported,
        "loaded_models": resident,
    }
@app.get("/model-info")
async def model_info(model_type: str = "sam-x-large"):
    """Get information about a specific model.

    Loads the model through ``model_manager.get_model`` and reports:

    * vocabulary size;
    * parameter count (0 when no model object is returned);
    * ``is_model_loaded`` status;
    * config values, falling back to 2048 positions, 12 layers, 768 hidden
      size and 12 heads when a key is absent.

    Raises:
        HTTPException: 404 if ``model_type`` is not an available model;
            500 if loading or inspecting the model fails.
    """
    try:
        available_models = model_manager.list_available_models()
        if model_type not in available_models:
            raise HTTPException(
                status_code=404,
                detail=f"Model {model_type} not available. Available: {available_models}"
            )
        model, tokenizer, config = model_manager.get_model(model_type)
        return {
            "model_type": model_type,
            "vocab_size": tokenizer.get_vocab_size(),
            # Explicit None check rather than relying on the model object's truthiness.
            "parameters": int(model.count_params()) if model is not None else 0,
            "max_context_length": config.get('max_position_embeddings', 2048),
            "loaded": model_manager.is_model_loaded(model_type),
            "num_hidden_layers": config.get('num_hidden_layers', 12),
            "hidden_size": config.get('hidden_size', 768),
            "num_attention_heads": config.get('num_attention_heads', 12)
        }
    except HTTPException:
        # Keep the 404 above instead of converting it into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error getting model info: {str(e)}")
@app.get("/models")
async def list_models():
    """Return every model the manager can serve, in OpenAI ``/models`` list format.

    Each entry gets the current Unix time (whole seconds) as ``created`` and
    ``"universal-worker"`` as its owner.
    """
    entries = []
    for name in model_manager.list_available_models():
        entry = {
            "id": name,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "universal-worker",
        }
        entries.append(entry)
    return {"object": "list", "data": entries}
@app.post("/saccp/process-task")
async def process_saccp_task(request: dict):
    """Process a SACCP task - interface for distributed computing.

    Request keys:
        task_type: ``"inference"`` (default) or ``"token_generation"``.
        model_name: Model to use (default ``"sam-x-large"``).
        task_data: Task-specific payload (default ``{}``).

            * ``inference``: ``prompt`` (default ``""``), ``max_tokens`` (512)
              and ``temperature`` (0.8). Also accepts optional ``top_k`` (40),
              ``top_p`` (0.9) and ``repetition_penalty`` (1.1).
            * ``token_generation``: ``current_context``, a non-empty list of
              token ids, and optional ``generation_params`` with
              ``temperature``/``top_k``/``top_p``/``repetition_penalty``.

    Returns:
        ``inference``: ``{"status", "result", "model_used"}``.
        ``token_generation``: ``{"status", "token_id", "token_text",
        "model_used", "next_position"}``.

    Raises:
        HTTPException: 400 for a missing context or an unsupported task type;
            500 for any other failure.
    """
    try:
        task_type = request.get("task_type", "inference")
        model_type = request.get("model_name", "sam-x-large")
        task_data = request.get("task_data", {})

        # Get the appropriate model and tokenizer
        model, tokenizer, config = model_manager.get_model(model_type)

        if task_type == "inference":
            generation_kwargs = dict(
                model=model,
                tokenizer=tokenizer,
                config=config,
                prompt=task_data.get("prompt", ""),
                max_tokens=task_data.get("max_tokens", 512),
                temperature=task_data.get("temperature", 0.8),
                top_k=task_data.get("top_k", 40),
                top_p=task_data.get("top_p", 0.9),
                repetition_penalty=task_data.get("repetition_penalty", 1.1),
            )
            # Full generation is synchronous and CPU-bound; keep it off the event loop.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None, lambda: generate_response(**generation_kwargs)
            )
            return {
                "status": "success",
                "result": result,
                "model_used": model_type
            }

        if task_type == "token_generation":
            # Single autoregressive step on a caller-supplied context.
            current_context = task_data.get("current_context", [])
            generation_params = task_data.get("generation_params", {})
            if not current_context:
                raise HTTPException(status_code=400, detail="Current context required for token generation")

            input_ids = tf.constant([current_context], dtype=tf.int32)
            with tf.device('/CPU:0'):
                logits, _ = model(input_ids, training=False, use_cache=False)
            # Logits for the last position predict the next token.
            next_token_logits = logits[0, -1, :].numpy()

            next_token_id = sample_token(
                next_token_logits,
                generation_params.get("temperature", 0.8),
                generation_params.get("top_k", 40),
                generation_params.get("top_p", 0.9),
                generation_params.get("repetition_penalty", 1.1),
            )
            token_text = tokenizer.decode([next_token_id])
            return {
                "status": "success",
                "token_id": int(next_token_id),
                "token_text": token_text,
                "model_used": model_type,
                "next_position": len(current_context)
            }

        # For other task types, we can extend this
        raise HTTPException(status_code=400, detail=f"Task type {task_type} not supported")
    except HTTPException:
        # Keep intentional 400s rather than masking them as 500s.
        raise
    except Exception as e:
        print(f"Error processing SACCP task: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing SACCP task: {str(e)}")
if __name__ == "__main__":
    # Serve on all interfaces; the listening port comes from $PORT (default 8000).
    listen_port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)