import os

# TensorFlow consults several of these variables when it is first imported /
# initialised (C++ log level, oneDNN, visible CUDA devices), so they are set
# BEFORE tensorflow -- or any project module that may import it -- is loaded.
# Assigning them after `import tensorflow` may leave them without effect.
NUM_CORES = os.cpu_count() or 4
os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES)
os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # Force CPU only
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'  # Intel optimization
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Reduce TF logging

import time
import json
import asyncio
import functools
from datetime import datetime
from typing import Dict, List, Optional, Sequence
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
import uvicorn
from pydantic import BaseModel
from shared.models import ChatRequest, ChatResponse, ChatMessage
import tensorflow as tf
import keras
import numpy as np
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import requests
from transformers import GPT2Tokenizer
from shared.model_manager import ModelManager

# Configure TF threading before anything builds or runs a model: these setters
# raise RuntimeError once the TF runtime has been initialised, so they must run
# before ModelManager is constructed below.
tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN enabled")

app = FastAPI(
    title="Universal Worker Node for Sam-X Models",
    description="Processing node that supports all Sam-X model types dynamically",
    version="2.0.0"
)

# Global model manager instance
model_manager = ModelManager()
model_loaded = True  # Always true since we're using lazy loading

# Text after which a decoded completion is cut off. The empty entry mirrors the
# empty-string special token looked up in _stop_token_ids; empty strings are
# skipped when truncating because "".find() matches at index 0 and would erase
# the whole completion.
# NOTE(review): the empty strings here (and in token_to_id("")) look like a
# special-token literal that was lost; confirm the intended end-of-turn marker.
_STOP_STRINGS = ("\n", "")


def format_chat_prompt(messages: List[Dict[str, str]]) -> str:
    """Format chat messages into a single prompt string for the model.

    Args:
        messages: Dicts with optional ``role`` (defaults to ``"user"``) and
            ``content`` (defaults to ``""``) keys.

    Returns:
        The concatenated prompt, ending with the assistant-turn prefix.
    """
    # NOTE(review): the user and assistant branches currently produce identical
    # text, so the model cannot tell the roles apart; the role markers appear to
    # be missing -- confirm against the model's chat template.
    parts = []
    for msg in messages:
        role = msg.get('role', 'user')
        content = msg.get('content', '')
        if role.lower() == 'user':
            parts.append(f""" {content} """)
        elif role.lower() == 'assistant':
            parts.append(f""" {content} """)
        else:  # System or other roles
            parts.append(f"{content}\n")
    # Add assistant prefix for the response
    parts.append(""" """)
    return "".join(parts)


def sample_token(logits, temperature=0.8, top_k=40, top_p=0.9, repetition_penalty=1.1,
                 context_ids: Optional[Sequence[int]] = None):
    """Sample the next token id from a 1-D logits vector.

    Args:
        logits: Unnormalised scores, one per vocabulary entry.
        temperature: Softmax temperature; ``<= 0`` selects greedy (argmax) decoding.
        top_k: Keep only the ``top_k`` most likely tokens (``<= 0`` disables).
        top_p: Nucleus threshold applied after top-k (``>= 1.0`` disables).
        repetition_penalty: CTRL-style penalty for tokens listed in ``context_ids``:
            positive logits are divided by it, negative ones multiplied.
        context_ids: Token ids already in the sequence. When omitted, no
            repetition penalty is applied (there is nothing to penalise).

    Returns:
        The sampled token id as a Python ``int``.
    """
    scores = np.array(logits, dtype=np.float64)  # copy: never mutate the caller's array

    # Repetition penalty on raw logits, only for tokens already present.
    if repetition_penalty != 1.0 and context_ids is not None and len(context_ids) > 0:
        seen = np.unique(np.asarray(context_ids, dtype=np.int64))
        seen = seen[(seen >= 0) & (seen < scores.shape[-1])]
        picked = scores[seen]
        scores[seen] = np.where(picked < 0, picked * repetition_penalty, picked / repetition_penalty)

    if temperature <= 0:
        return int(np.argmax(scores))

    scores = scores / temperature
    probs = np.exp(scores - np.max(scores))  # Numerical stability
    probs = probs / np.sum(probs)

    candidates = np.arange(len(probs))
    # Top-k filtering
    if 0 < top_k < len(probs):
        candidates = np.argpartition(probs, -top_k)[-top_k:]
    # Top-p (nucleus) filtering, applied on the renormalised top-k distribution
    if top_p < 1.0:
        candidates = candidates[np.argsort(probs[candidates])[::-1]]
        kept = probs[candidates]
        cumulative = np.cumsum(kept) / np.sum(kept)
        cutoff = min(int(np.searchsorted(cumulative, top_p)) + 1, len(candidates))
        candidates = candidates[:cutoff]

    candidate_probs = probs[candidates]
    candidate_probs = candidate_probs / np.sum(candidate_probs)  # Normalize
    return int(candidates[np.random.choice(len(candidates), p=candidate_probs)])


def _max_context_length(config: dict) -> int:
    """Return the longest token window fed to the model (same default as /model-info)."""
    return max(1, int(config.get('max_position_embeddings') or 2048))


def _stop_token_ids(tokenizer: Tokenizer, config: dict) -> set:
    """Collect the token ids that terminate generation, ignoring unknown (None) ids."""
    eos = config.get('eos_token_id', 50256)
    candidates = list(eos) if isinstance(eos, (list, tuple)) else [eos]
    candidates += [tokenizer.token_to_id("\n"), tokenizer.token_to_id("")]
    return {int(tid) for tid in candidates if tid is not None}


def _truncate_at_stop_strings(text: str) -> str:
    """Cut ``text`` at each non-empty stop string (see _STOP_STRINGS)."""
    for stop in _STOP_STRINGS:
        if not stop:
            continue  # "".find() returns 0 and would wipe the whole text
        idx = text.find(stop)
        if idx != -1:
            text = text[:idx]
    return text


def _encode_prompt(tokenizer: Tokenizer, prompt: str) -> List[int]:
    """Tokenize ``prompt``; raise ValueError if it yields no tokens to condition on."""
    ids = list(tokenizer.encode(prompt).ids)
    if not ids:
        raise ValueError("Prompt encoded to zero tokens; nothing to generate from")
    return ids


def _predict_next_token(model: keras.Model, context_ids: List[int], max_context: int,
                        temperature: float, top_k: int, top_p: float,
                        repetition_penalty: float) -> int:
    """Run one forward pass over the (windowed) full context and sample a token.

    There is no KV cache (``use_cache=False``), so the whole sequence must be
    fed every step; only the last ``max_context`` ids are used.
    """
    window = context_ids[-max_context:]
    input_ids = tf.constant([window], dtype=tf.int32)
    with tf.device('/CPU:0'):  # Use CPU for inference
        logits, _ = model(input_ids, training=False, use_cache=False)
    next_token_logits = logits[0, -1, :].numpy()
    return sample_token(next_token_logits, temperature, top_k, top_p,
                        repetition_penalty, context_ids=window)


def generate_response(model: keras.Model, tokenizer: Tokenizer, config: dict, prompt: str,
                      max_tokens: int = 512, temperature: float = 0.8, top_k: int = 40,
                      top_p: float = 0.9, repetition_penalty: float = 1.1) -> str:
    """Generate a complete response for ``prompt`` (blocking).

    Generation stops after ``max_tokens`` tokens or when a stop token id is
    sampled; the decoded text is cut at the first stop string and stripped.

    Raises:
        ValueError: if the prompt tokenizes to nothing.
    """
    context_ids = _encode_prompt(tokenizer, prompt)
    max_context = _max_context_length(config)
    stop_ids = _stop_token_ids(tokenizer, config)

    generated_ids: List[int] = []
    for _ in range(max_tokens):
        next_token_id = _predict_next_token(model, context_ids, max_context,
                                            temperature, top_k, top_p, repetition_penalty)
        generated_ids.append(next_token_id)
        context_ids.append(next_token_id)
        if next_token_id in stop_ids:
            break

    return _truncate_at_stop_strings(tokenizer.decode(generated_ids)).strip()


def _sse_chunk(chunk_id: str, model_name: str, delta: dict, finish_reason=None) -> str:
    """Serialise one OpenAI-style ``chat.completion.chunk`` as a server-sent event."""
    payload = {
        "id": chunk_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model_name,
        "choices": [{
            "index": 0,
            "delta": delta,
            "finish_reason": finish_reason
        }]
    }
    return f"data: {json.dumps(payload)}\n\n"


async def _stream_text_deltas(model: keras.Model, tokenizer: Tokenizer, config: dict, prompt: str,
                              max_tokens: int, temperature: float, top_k: int, top_p: float,
                              repetition_penalty: float):
    """Generate token by token, yielding ``(step, new_text)`` pairs.

    Each forward pass runs in the default executor so the event loop keeps
    serving other requests. Text is obtained by decoding the whole generated
    sequence and emitting only the new suffix, so multi-token characters are
    not split; a trailing U+FFFD (incomplete character) is held back until the
    next token. Output is cut at stop strings, like generate_response.
    """
    loop = asyncio.get_running_loop()
    context_ids = _encode_prompt(tokenizer, prompt)
    max_context = _max_context_length(config)
    stop_ids = _stop_token_ids(tokenizer, config)

    generated_ids: List[int] = []
    emitted = ""
    for step in range(max_tokens):
        next_token_id = await loop.run_in_executor(None, functools.partial(
            _predict_next_token, model, context_ids, max_context,
            temperature, top_k, top_p, repetition_penalty))
        generated_ids.append(next_token_id)
        context_ids.append(next_token_id)

        full_text = tokenizer.decode(generated_ids)
        visible = _truncate_at_stop_strings(full_text)
        finished = next_token_id in stop_ids or len(visible) < len(full_text)
        if not finished and visible.endswith("\ufffd"):
            continue  # partial multi-byte character; wait for the next token

        delta = visible[len(emitted):]
        if delta:
            emitted = visible
            yield step, delta
        if finished:
            break


async def generate_streaming_response(model: keras.Model, tokenizer: Tokenizer, config: dict, prompt: str,
                                      max_tokens: int = 512, temperature: float = 0.8, top_k: int = 40,
                                      top_p: float = 0.9, repetition_penalty: float = 1.1,
                                      model_name: str = "dynamic_model"):
    """Stream an OpenAI-compatible chat completion as server-sent events.

    Yields a role chunk, one content chunk per new piece of text, then a final
    chunk with ``finish_reason="stop"``. All chunks share one completion id and
    report ``model_name``.
    """
    completion_id = f"chat-{int(time.time())}"
    yield _sse_chunk(completion_id, model_name, {"role": "assistant", "content": ""})
    # Process tokens one by one with streaming - this is where SACCP token distribution happens
    async for _, text in _stream_text_deltas(model, tokenizer, config, prompt, max_tokens,
                                             temperature, top_k, top_p, repetition_penalty):
        yield _sse_chunk(completion_id, model_name, {"content": text})
    yield _sse_chunk(completion_id, model_name, {}, "stop")


async def generate_token_by_token_streaming_response(model: keras.Model, tokenizer: Tokenizer, config: dict,
                                                     prompt: str, max_tokens: int = 512,
                                                     temperature: float = 0.8, top_k: int = 40,
                                                     top_p: float = 0.9, repetition_penalty: float = 1.1,
                                                     model_name: str = "dynamic_model"):
    """Stream a completion token by token, suitable for SACCP distribution.

    Same event sequence as generate_streaming_response, but each content chunk
    carries an id of the form ``token-<step>-<unix time>``.
    """
    yield _sse_chunk(f"chat-{int(time.time())}", model_name, {"role": "assistant", "content": ""})
    # Process one token at a time (in a real SACCP scenario, this could be distributed)
    async for step, text in _stream_text_deltas(model, tokenizer, config, prompt, max_tokens,
                                                temperature, top_k, top_p, repetition_penalty):
        yield _sse_chunk(f"token-{step}-{int(time.time())}", model_name, {"content": text})
    yield _sse_chunk(f"chat-{int(time.time())}", model_name, {}, "stop")


@app.on_event("startup")
def startup_event():
    """Log the models this worker can serve; mark it unhealthy if that fails."""
    global model_loaded
    print("Initializing universal worker...")
    try:
        print(f"Available models: {model_manager.list_available_models()}")
        print("✅ Universal worker initialized successfully!")
        print("This worker can dynamically load any Sam-X model based on requests")
    except Exception as e:
        print(f"❌ Worker initialization failed: {e}")
        model_loaded = False


@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
    """Process an OpenAI-compatible chat completion request.

    Unknown model names fall back to the first available model whose name
    contains (or is contained in) the requested one; otherwise 400.
    Blocking inference runs in the default executor.
    """
    try:
        # Extract model type from request
        model_type = request.model.lower()

        # Validate model type
        available_models = model_manager.list_available_models()
        if model_type not in available_models:
            # Find closest matching model
            matching_models = [m for m in available_models if model_type in m or m in model_type]
            if matching_models:
                model_type = matching_models[0]  # Use first available match
            else:
                raise HTTPException(
                    status_code=400,
                    detail=f"Model {request.model} not available. Available models: {available_models}"
                )

        # Get the appropriate model and tokenizer for this request
        model, tokenizer, config = model_manager.get_model(model_type)

        # Format the messages into a single prompt
        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
        prompt = format_chat_prompt(messages)

        generation_kwargs = dict(
            model=model,
            tokenizer=tokenizer,
            config=config,
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            repetition_penalty=request.repetition_penalty,
        )

        # If streaming is requested, return StreamingResponse; the model name is
        # passed in directly rather than by re-parsing each SSE chunk.
        if request.stream:
            return StreamingResponse(
                generate_streaming_response(model_name=request.model, **generation_kwargs),
                media_type="text/event-stream",
            )

        # Otherwise, generate full response off the event loop
        start_time = time.perf_counter()
        loop = asyncio.get_running_loop()
        response_text = await loop.run_in_executor(
            None, functools.partial(generate_response, **generation_kwargs))
        processing_time = time.perf_counter() - start_time

        # Token counts (completion re-encoded from the final, trimmed text)
        prompt_tokens = len(tokenizer.encode(prompt).ids)
        completion_tokens = len(tokenizer.encode(response_text).ids) if response_text else 0

        # Create response in OpenAI-compatible format
        response = ChatResponse(
            id=f"chat-{int(time.time())}",
            model=request.model,  # Use original model name
            choices=[
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": response_text},
                    "finish_reason": "stop"
                }
            ],
            usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
        )

        print(f"Generated response in {processing_time:.2f}s for model {request.model} (loaded as {model_type})")
        return response.dict()

    except HTTPException:
        raise  # keep intentional 4xx responses intact
    except Exception as e:
        print(f"Error processing request: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy" if model_loaded else "unhealthy",
        "model_loaded": model_loaded,
        "timestamp": int(time.time()),
        "supported_models": model_manager.list_available_models(),
        "loaded_models": list(model_manager.models.keys())
    }


@app.get("/model-info")
async def model_info(model_type: str = "sam-x-large"):
    """Get information about a specific model (loads it if necessary); 404 if unknown."""
    try:
        if model_type not in model_manager.list_available_models():
            raise HTTPException(
                status_code=404,
                detail=f"Model {model_type} not available. Available: {model_manager.list_available_models()}"
            )

        model, tokenizer, config = model_manager.get_model(model_type)

        return {
            "model_type": model_type,
            "vocab_size": tokenizer.get_vocab_size(),
            "parameters": int(model.count_params()) if model else 0,
            "max_context_length": config.get('max_position_embeddings', 2048),
            "loaded": model_manager.is_model_loaded(model_type),
            "num_hidden_layers": config.get('num_hidden_layers', 12),
            "hidden_size": config.get('hidden_size', 768),
            "num_attention_heads": config.get('num_attention_heads', 12)
        }
    except HTTPException:
        raise  # preserve the 404 instead of converting it to 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error getting model info: {str(e)}") from e


@app.get("/models")
async def list_models():
    """List all available models"""
    return {
        "object": "list",
        "data": [
            {
                "id": model_name,
                "object": "model",
                "created": int(time.time()),
                "owned_by": "universal-worker"
            }
            for model_name in model_manager.list_available_models()
        ]
    }


@app.post("/saccp/process-task")
async def process_saccp_task(request: dict):
    """Process a SACCP task - interface for distributed computing.

    Supported ``task_type`` values:
      * ``"inference"``: full generation from ``task_data["prompt"]``.
      * ``"token_generation"``: sample one token given ``task_data["current_context"]``
        (a list of token ids) and optional ``generation_params``.
    Unsupported task types or a missing context return 400.
    """
    try:
        task_type = request.get("task_type", "inference")
        model_type = request.get("model_name", "sam-x-large")
        task_data = request.get("task_data", {})

        # Get the appropriate model and tokenizer
        model, tokenizer, config = model_manager.get_model(model_type)
        loop = asyncio.get_running_loop()

        if task_type == "inference":
            result = await loop.run_in_executor(None, functools.partial(
                generate_response,
                model=model,
                tokenizer=tokenizer,
                config=config,
                prompt=task_data.get("prompt", ""),
                max_tokens=task_data.get("max_tokens", 512),
                temperature=task_data.get("temperature", 0.8),
                top_k=task_data.get("top_k", 40),
                top_p=task_data.get("top_p", 0.9),
                repetition_penalty=task_data.get("repetition_penalty", 1.1),
            ))
            return {
                "status": "success",
                "result": result,
                "model_used": model_type
            }

        if task_type == "token_generation":
            # Handle token-by-token generation task for autoregressive models
            current_context = task_data.get("current_context", [])
            generation_params = task_data.get("generation_params", {})

            if not current_context:
                raise HTTPException(status_code=400, detail="Current context required for token generation")

            next_token_id = await loop.run_in_executor(None, functools.partial(
                _predict_next_token,
                model,
                list(current_context),
                _max_context_length(config),
                generation_params.get("temperature", 0.8),
                generation_params.get("top_k", 40),
                generation_params.get("top_p", 0.9),
                generation_params.get("repetition_penalty", 1.1),
            ))

            # Decode token to text
            token_text = tokenizer.decode([next_token_id])

            return {
                "status": "success",
                "token_id": int(next_token_id),
                "token_text": token_text,
                "model_used": model_type,
                "next_position": len(current_context)
            }

        # For other task types, we can extend this
        raise HTTPException(status_code=400, detail=f"Task type {task_type} not supported")

    except HTTPException:
        raise  # keep intentional 400s instead of converting them to 500
    except Exception as e:
        print(f"Error processing SACCP task: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing SACCP task: {str(e)}") from e


if __name__ == "__main__":
    port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)