# NOTE: removed paste artifacts ("Spaces:" header and duplicated "Runtime error" lines)
| import os | |
| import time | |
| import json | |
| import asyncio | |
| from datetime import datetime | |
| from typing import Dict, List, Optional | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| import uvicorn | |
| from pydantic import BaseModel | |
| from shared.models import ChatRequest, ChatResponse, ChatMessage | |
| import tensorflow as tf | |
| import keras | |
| import numpy as np | |
| from tokenizers import Tokenizer | |
| from huggingface_hub import hf_hub_download | |
| import requests | |
| from transformers import GPT2Tokenizer | |
| from shared.model_manager import ModelManager | |
# FastAPI application exposing this worker's inference endpoints.
app = FastAPI(
    title="Universal Worker Node for Sam-X Models",
    description="Processing node that supports all Sam-X model types dynamically",
    version="2.0.0"
)
# Global model manager instance
model_manager = ModelManager()
model_loaded = True  # Always true since we're using lazy loading
# Performance optimizations
NUM_CORES = os.cpu_count() or 4
# NOTE(review): TensorFlow generally reads these environment variables at
# import time; since `tensorflow` is already imported above, these settings
# may have no effect — the explicit tf.config calls below are what actually
# apply the thread counts. Confirm and drop the env lines if redundant.
os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES)
os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # Force CPU only
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'  # Intel optimization
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Reduce TF logging
# Configure TF threading (runtime API; must run before any ops execute)
tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN enabled")
def format_chat_prompt(messages: List[Dict[str, str]]) -> str:
    """Flatten a list of chat messages into a single prompt string.

    User and assistant turns are wrapped in surrounding newlines; any other
    role (e.g. system) contributes its content followed by a single newline.
    A trailing newline is appended so the model begins its reply on a fresh
    line.

    Args:
        messages: Dicts with optional 'role' (defaults to 'user') and
            'content' (defaults to '') keys.

    Returns:
        The concatenated prompt text.
    """
    parts = []
    for message in messages:
        role = message.get('role', 'user').lower()
        content = message.get('content', '')
        if role in ('user', 'assistant'):
            # Conversation turns are set off by blank-line boundaries.
            parts.append("\n" + content + "\n")
        else:
            # System or other roles: content plus a newline.
            parts.append(content + "\n")
    # Leave the cursor on a fresh line for the assistant's response.
    parts.append("\n")
    return "".join(parts)
def sample_token(logits, temperature=0.8, top_k=40, top_p=0.9, repetition_penalty=1.1):
    """Sample a next-token id from an unnormalized logits vector.

    Args:
        logits: 1-D numpy array of raw model logits.
        temperature: Softmax temperature. Values <= 0 now fall back to
            greedy argmax (previously this divided by zero and produced
            NaN probabilities).
        top_k: Keep only the k most likely tokens (0 disables).
        top_p: Nucleus-sampling threshold; only applied when top-k
            filtering did not trigger — the two are mutually exclusive here.
        repetition_penalty: Scales logits to damp extremes.
            NOTE(review): this is applied to *all* logits, not just tokens
            already generated — a true repetition penalty needs the
            generation history; confirm whether this global damping is
            intended.

    Returns:
        An integer index into `logits`.
    """
    # Greedy fallback: avoids the former ZeroDivisionError / NaN probabilities
    # when callers request fully deterministic decoding.
    if temperature <= 0:
        return int(np.argmax(logits))
    # Apply temperature
    logits = logits / temperature
    # Apply repetition penalty (see docstring caveat)
    if repetition_penalty != 1.0:
        logits = np.where(logits < 0, logits * repetition_penalty, logits / repetition_penalty)
    # Convert to probabilities; subtracting the max keeps exp() stable.
    probs = np.exp(logits - np.max(logits))
    probs = probs / np.sum(probs)
    # Top-k filtering: sample only among the k highest-probability tokens.
    if top_k > 0 and top_k < len(probs):
        top_k_idx = np.argpartition(probs, -top_k)[-top_k:]
        top_k_probs = probs[top_k_idx]
        top_k_probs = top_k_probs / np.sum(top_k_probs)  # Normalize
        sampled_idx = np.random.choice(len(top_k_idx), p=top_k_probs)
        return top_k_idx[sampled_idx]
    # Top-p (nucleus) sampling: smallest prefix of the sorted distribution
    # whose cumulative mass covers top_p.
    if top_p < 1.0:
        sorted_idx = np.argsort(probs)[::-1]
        sorted_probs = probs[sorted_idx]
        cumulative_probs = np.cumsum(sorted_probs)
        cutoff_idx = np.searchsorted(cumulative_probs, top_p)
        cutoff_idx = min(cutoff_idx + 1, len(sorted_idx))
        nucleus_idx = sorted_idx[:cutoff_idx]
        nucleus_probs = probs[nucleus_idx]
        nucleus_probs = nucleus_probs / np.sum(nucleus_probs)  # Normalize
        sampled_idx = np.random.choice(len(nucleus_idx), p=nucleus_probs)
        return nucleus_idx[sampled_idx]
    # Unfiltered sampling from the full distribution.
    return np.random.choice(len(probs), p=probs)
def generate_response(model: keras.Model, tokenizer: Tokenizer, config: dict,
                      prompt: str, max_tokens: int = 512, temperature: float = 0.8,
                      top_k: int = 40, top_p: float = 0.9, repetition_penalty: float = 1.1) -> str:
    """Generate a complete (non-streaming) text response for `prompt`.

    Runs simple autoregressive decoding on CPU, stopping at EOS, a newline
    token, or the model's end marker, then trims any textual stop marker
    from the decoded output.

    Args:
        model: Keras model returning (logits, cache) when called.
        tokenizer: HuggingFace `tokenizers.Tokenizer` for encode/decode.
        config: Model config dict; `eos_token_id` is read (default 50256).
        prompt: Already-formatted prompt text.
        max_tokens: Upper bound on generated tokens.
        temperature / top_k / top_p / repetition_penalty: Sampling knobs,
            forwarded to `sample_token`.

    Returns:
        The generated text, stripped of stop markers and whitespace.
    """
    prompt_ids = tokenizer.encode(prompt).ids

    # Stop ids are fixed for the whole generation — compute them once instead
    # of every iteration. tokenizer.token_to_id returns None for unknown
    # tokens; drop those so the membership test only holds real ids (the old
    # `and next_token_id is not None` check was always true and guarded
    # nothing).
    eos_token_id = config.get('eos_token_id', 50256)
    stop_token_ids = {tid for tid in (eos_token_id,
                                      tokenizer.token_to_id("\n"),
                                      tokenizer.token_to_id("<im end for model tun>"))
                      if tid is not None}

    generated_ids: List[int] = []
    for _ in range(max_tokens):
        # BUG FIX: the model is called with use_cache=False, so it is
        # stateless between calls — it must see the full prompt plus all
        # generated tokens every step. Previously only the single last token
        # was fed back, discarding the entire context.
        current_ids = tf.constant([prompt_ids + generated_ids], dtype=tf.int32)
        with tf.device('/CPU:0'):  # CPU-only inference (CUDA disabled above)
            logits, _ = model(current_ids, training=False, use_cache=False)
            next_token_logits = logits[0, -1, :].numpy()

        next_token_id = int(sample_token(next_token_logits, temperature,
                                         top_k, top_p, repetition_penalty))
        generated_ids.append(next_token_id)

        if next_token_id in stop_token_ids:
            break

    generated_text = tokenizer.decode(generated_ids)

    # Trim anything at or after a textual stop marker that survived decoding.
    for stop in ("\n", "<im end for model tun>"):
        idx = generated_text.find(stop)
        if idx != -1:
            generated_text = generated_text[:idx]
    return generated_text.strip()
async def generate_streaming_response(model: keras.Model, tokenizer: Tokenizer, config: dict,
                                      prompt: str, max_tokens: int = 512, temperature: float = 0.8,
                                      top_k: int = 40, top_p: float = 0.9, repetition_penalty: float = 1.1):
    """Yield an OpenAI-style SSE stream of chat.completion.chunk events.

    Emits an initial role chunk, one chunk per sampled token, and a final
    chunk with finish_reason "stop". Each yielded item is a full SSE frame
    of the form "data: <json>\\n\\n". The "model" field is a placeholder that
    the calling endpoint overwrites with the requested model name.

    Args mirror `generate_response`.
    """
    # `json` and `time` are imported at module level; the previous local
    # re-imports were redundant and have been removed.

    def _sse_chunk(delta: dict, finish_reason=None) -> str:
        """Serialize one completion chunk as an SSE data frame."""
        payload = {
            "id": f"chat-{int(time.time())}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": "dynamic_model",  # Overwritten by the calling endpoint
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason
            }]
        }
        return f"data: {json.dumps(payload)}\n\n"

    prompt_ids = tokenizer.encode(prompt).ids

    # Stop ids are loop-invariant; drop None entries for unknown tokens so
    # the membership test below only contains real ids.
    eos_token_id = config.get('eos_token_id', 50256)
    stop_token_ids = {tid for tid in (eos_token_id,
                                      tokenizer.token_to_id("\n"),
                                      tokenizer.token_to_id("<im end for model tun>"))
                      if tid is not None}

    # Initial chunk announcing the assistant role.
    yield _sse_chunk({"role": "assistant", "content": ""})

    generated_ids = []
    for _ in range(max_tokens):
        # BUG FIX: with use_cache=False the model is stateless between calls,
        # so it must be fed the full prompt + generated ids every step.
        # Previously only the last token was passed, losing all context.
        current_ids = tf.constant([prompt_ids + generated_ids], dtype=tf.int32)
        with tf.device('/CPU:0'):  # CPU-only inference
            logits, _ = model(current_ids, training=False, use_cache=False)
            next_token_logits = logits[0, -1, :].numpy()

        next_token_id = int(sample_token(next_token_logits, temperature,
                                         top_k, top_p, repetition_penalty))
        generated_ids.append(next_token_id)

        # Stream this token's text to the client.
        yield _sse_chunk({"content": tokenizer.decode([next_token_id])})

        if next_token_id in stop_token_ids:
            break

    # Final chunk signalling completion.
    yield _sse_chunk({}, finish_reason="stop")
async def generate_token_by_token_streaming_response(model: keras.Model, tokenizer: Tokenizer, config: dict,
                                                     prompt: str, max_tokens: int = 512, temperature: float = 0.8,
                                                     top_k: int = 40, top_p: float = 0.9, repetition_penalty: float = 1.1):
    """Yield an SSE stream with explicit per-token chunk ids (SACCP-oriented).

    Identical decoding to `generate_streaming_response`, but each token
    chunk carries an id of the form "token-<i>-<timestamp>" so individual
    steps can be tracked when distributed. Yields SSE frames
    "data: <json>\\n\\n": a role chunk, one chunk per token, then a final
    chunk with finish_reason "stop".

    Args mirror `generate_response`.
    """
    # `json` and `time` are imported at module level; the previous local
    # re-imports were redundant and have been removed.
    prompt_ids = tokenizer.encode(prompt).ids

    # Loop-invariant stop ids; drop None results from token_to_id lookups.
    eos_token_id = config.get('eos_token_id', 50256)
    stop_token_ids = {tid for tid in (eos_token_id,
                                      tokenizer.token_to_id("\n"),
                                      tokenizer.token_to_id("<im end for model tun>"))
                      if tid is not None}

    # Initial chunk with the assistant role.
    initial_chunk = {
        "id": f"chat-{int(time.time())}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": "dynamic_model",
        "choices": [{
            "index": 0,
            "delta": {"role": "assistant", "content": ""},
            "finish_reason": None
        }]
    }
    yield f"data: {json.dumps(initial_chunk)}\n\n"

    generated_ids = []
    for i in range(max_tokens):
        # BUG FIX: with use_cache=False the model has no memory of earlier
        # steps — feed the full prompt + generated ids each iteration
        # (previously only the single last token was passed).
        current_ids = tf.constant([prompt_ids + generated_ids], dtype=tf.int32)
        with tf.device('/CPU:0'):
            logits, _ = model(current_ids, training=False, use_cache=False)
            next_token_logits = logits[0, -1, :].numpy()

        next_token_id = int(sample_token(next_token_logits, temperature,
                                         top_k, top_p, repetition_penalty))
        generated_ids.append(next_token_id)
        token_text = tokenizer.decode([next_token_id])

        # Per-token chunk; the id encodes the step index for traceability.
        chunk = {
            "id": f"token-{i}-{int(time.time())}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": "dynamic_model",
            "choices": [{
                "index": 0,
                "delta": {"content": token_text},
                "finish_reason": None
            }]
        }
        yield f"data: {json.dumps(chunk)}\n\n"

        if next_token_id in stop_token_ids:
            break

    # Final chunk signalling completion.
    final_chunk = {
        "id": f"chat-{int(time.time())}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": "dynamic_model",
        "choices": [{
            "index": 0,
            "delta": {},
            "finish_reason": "stop"
        }]
    }
    yield f"data: {json.dumps(final_chunk)}\n\n"
def startup_event():
    """Initialize model manager on startup.

    Models are loaded lazily per request, so this only logs availability.
    NOTE(review): no @app.on_event("startup") decorator is visible in this
    chunk — confirm this hook is registered somewhere, otherwise it never
    runs. Also note the try body contains only print() calls, so the except
    branch (which flips model_loaded to False) is effectively unreachable.
    """
    global model_loaded
    print("Initializing universal worker...")
    print(f"Available models: {model_manager.list_available_models()}")
    try:
        print("✅ Universal worker initialized successfully!")
        print("This worker can dynamically load any Sam-X model based on requests")
    except Exception as e:
        print(f"❌ Worker initialization failed: {e}")
        model_loaded = False
async def chat_completions(request: ChatRequest):
    """Handle an OpenAI-style chat completion request.

    Resolves the requested model name (accepting substring matches),
    formats the conversation into a prompt, and returns either an SSE
    stream or a complete ChatResponse dict.

    NOTE(review): no @app.post(...) decorator is visible in this chunk —
    confirm the route is registered elsewhere.

    Raises:
        HTTPException: 400 when the model is unknown, 500 on internal errors.
    """
    try:
        model_type = request.model.lower()

        # Resolve the model, accepting partial matches in either direction
        # (e.g. "sam-x" matches "sam-x-large").
        available_models = model_manager.list_available_models()
        if model_type not in available_models:
            matching_models = [m for m in available_models
                               if model_type in m or m in model_type]
            if matching_models:
                model_type = matching_models[0]  # Use first available match
            else:
                raise HTTPException(
                    status_code=400,
                    detail=f"Model {request.model} not available. Available models: {available_models}"
                )

        # Lazily load (or fetch cached) model, tokenizer and config.
        model, tokenizer, config = model_manager.get_model(model_type)

        # Flatten the chat history into a single prompt string.
        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
        prompt = format_chat_prompt(messages)

        if request.stream:
            async def generate():
                async for chunk in generate_streaming_response(
                    model=model,
                    tokenizer=tokenizer,
                    config=config,
                    prompt=prompt,
                    max_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_k=request.top_k,
                    top_p=request.top_p,
                    repetition_penalty=request.repetition_penalty
                ):
                    # Rewrite the placeholder model name with the requested
                    # one. Frames look like "data: {...}\n\n".
                    # BUG FIX: the old slice chunk[7:-4] cut off the JSON's
                    # leading "{" and trailing "]}" and made json.loads fail
                    # on every chunk; strip the SSE framing properly instead.
                    chunk_data = json.loads(chunk[len("data: "):].rstrip("\n"))
                    chunk_data["model"] = request.model
                    yield f"data: {json.dumps(chunk_data)}\n\n"
            return StreamingResponse(generate(), media_type="text/event-stream")

        # Non-streaming path: generate the full response and time it.
        start_time = time.time()
        response_text = generate_response(
            model=model,
            tokenizer=tokenizer,
            config=config,
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            repetition_penalty=request.repetition_penalty
        )
        processing_time = time.time() - start_time

        # BUG FIX: usage previously reported *character* counts as token
        # counts; compute real token counts with the tokenizer.
        prompt_tokens = len(tokenizer.encode(prompt).ids)
        completion_tokens = len(tokenizer.encode(response_text).ids)

        response = ChatResponse(
            id=f"chat-{int(time.time())}",
            model=request.model,  # Echo the originally requested name
            choices=[
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": response_text},
                    "finish_reason": "stop"
                }
            ],
            usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
        )
        print(f"Generated response in {processing_time:.2f}s for model {request.model} (loaded as {model_type})")
        return response.dict()
    except HTTPException:
        # BUG FIX: deliberate 4xx errors were previously swallowed by the
        # broad handler below and re-raised as generic 500s.
        raise
    except Exception as e:
        print(f"Error processing request: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
async def health_check():
    """Report worker health plus supported and currently-loaded models."""
    status = "healthy" if model_loaded else "unhealthy"
    return {
        "status": status,
        "model_loaded": model_loaded,
        "timestamp": int(time.time()),
        "supported_models": model_manager.list_available_models(),
        "loaded_models": list(model_manager.models.keys()),
    }
async def model_info(model_type: str = "sam-x-large"):
    """Return metadata for one model, loading it if necessary.

    Args:
        model_type: Name of the model to inspect.

    Raises:
        HTTPException: 404 for unknown models, 500 on load/introspection
        failures.
    """
    try:
        if model_type not in model_manager.list_available_models():
            raise HTTPException(
                status_code=404,
                detail=f"Model {model_type} not available. Available: {model_manager.list_available_models()}"
            )
        model, tokenizer, config = model_manager.get_model(model_type)
        return {
            "model_type": model_type,
            "vocab_size": tokenizer.get_vocab_size(),
            "parameters": int(model.count_params()) if model else 0,
            "max_context_length": config.get('max_position_embeddings', 2048),
            "loaded": model_manager.is_model_loaded(model_type),
            "num_hidden_layers": config.get('num_hidden_layers', 12),
            "hidden_size": config.get('hidden_size', 768),
            "num_attention_heads": config.get('num_attention_heads', 12)
        }
    except HTTPException:
        # BUG FIX: the 404 above was previously caught by the broad handler
        # below and re-raised as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error getting model info: {str(e)}")
async def list_models():
    """List every available model in the OpenAI-compatible /models format."""
    entries = []
    for model_name in model_manager.list_available_models():
        entries.append({
            "id": model_name,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "universal-worker",
        })
    return {"object": "list", "data": entries}
async def process_saccp_task(request: dict):
    """Process one SACCP distributed-computing task.

    Supported task types:
      - "inference": full-text generation for task_data["prompt"].
      - "token_generation": sample a single next token given
        task_data["current_context"] (a list of token ids).

    Args:
        request: Dict with "task_type", "model_name" and "task_data" keys,
            each optional with defaults.

    Raises:
        HTTPException: 400 for missing context or unsupported task types,
        500 on internal errors.
    """
    try:
        task_type = request.get("task_type", "inference")
        model_type = request.get("model_name", "sam-x-large")
        task_data = request.get("task_data", {})

        # Lazily load (or fetch cached) model, tokenizer and config.
        model, tokenizer, config = model_manager.get_model(model_type)

        if task_type == "inference":
            result = generate_response(
                model=model,
                tokenizer=tokenizer,
                config=config,
                prompt=task_data.get("prompt", ""),
                max_tokens=task_data.get("max_tokens", 512),
                temperature=task_data.get("temperature", 0.8)
            )
            return {
                "status": "success",
                "result": result,
                "model_used": model_type
            }

        if task_type == "token_generation":
            # Single-step generation for autoregressive models: the caller
            # supplies the full context and receives one sampled token back.
            current_context = task_data.get("current_context", [])
            generation_params = task_data.get("generation_params", {})
            if not current_context:
                raise HTTPException(status_code=400,
                                    detail="Current context required for token generation")

            input_ids = tf.constant([current_context], dtype=tf.int32)
            with tf.device('/CPU:0'):  # CPU-only inference
                logits, _ = model(input_ids, training=False, use_cache=False)
                # Logits at the last position predict the next token.
                next_token_logits = logits[0, -1, :].numpy()

            next_token_id = sample_token(
                next_token_logits,
                generation_params.get("temperature", 0.8),
                generation_params.get("top_k", 40),
                generation_params.get("top_p", 0.9),
                generation_params.get("repetition_penalty", 1.1)
            )
            return {
                "status": "success",
                "token_id": int(next_token_id),
                "token_text": tokenizer.decode([int(next_token_id)]),
                "model_used": model_type,
                "next_position": len(current_context)
            }

        # Unknown task type — extend here for new SACCP task kinds.
        raise HTTPException(status_code=400, detail=f"Task type {task_type} not supported")
    except HTTPException:
        # BUG FIX: the deliberate 400 errors above were previously caught by
        # the broad handler below and re-raised as generic 500s.
        raise
    except Exception as e:
        print(f"Error processing SACCP task: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing SACCP task: {str(e)}")
if __name__ == "__main__":
    # Launch the worker API; the PORT env var overrides the default 8000.
    listen_port = int(os.environ.get("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)