"""Worker node that serves Sam-X model inference over a FastAPI HTTP API.

On startup the worker loads a tokenizer and a model selected by the
``MODEL_TYPE`` environment variable, then exposes ``/chat/completions``,
``/health`` and ``/model-info`` endpoints.
"""

import os
import time
import json
import asyncio
import uuid
from datetime import datetime
from typing import Dict, List, Optional

# Performance optimizations.
# These variables are exported *before* TensorFlow is imported: the native
# runtime may read logging/device/threading settings while it is being
# loaded, so setting them after the import can be too late to take effect.
NUM_CORES = os.cpu_count() or 4
os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES)
os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # Force CPU only
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'  # Intel optimization
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Reduce TF logging

from fastapi import FastAPI, HTTPException  # noqa: E402
import uvicorn  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from shared.models import ChatRequest, ChatResponse, ChatMessage  # noqa: E402
import tensorflow as tf  # noqa: E402
import keras  # noqa: E402
import numpy as np  # noqa: E402
from tokenizers import Tokenizer  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402
import requests  # noqa: E402
from transformers import GPT2Tokenizer  # noqa: E402

app = FastAPI(
    title="Worker Node for Sam-X Models",
    description="Processing node for Sam-X model inference",
    version="1.0.0"
)

# Global variables for model and tokenizer
tokenizer = None
model = None
model_loaded = False
# Maximum sequence length of the loaded model; overwritten by load_model()
# with the value found in the downloaded config.
model_max_len = 2048

# Configuration
MODEL_REPO = os.getenv("MODEL_REPO", "Smilyai-labs/Sam-large-2")
MODEL_TYPE = os.getenv("MODEL_TYPE", "sam-x-nano")  # Determines which model to load
CACHE_DIR = "./model_cache"

# Hugging Face repository for each known MODEL_TYPE. Any other value falls
# back to MODEL_REPO (the large model by default).
MODEL_TYPE_REPOS = {
    "sam-x-nano": "Smilyai-labs/Sam-nano",
    "sam-x-mini": "Smilyai-labs/Sam-mini",
    "sam-x-fast": "Smilyai-labs/Sam-fast",
}

# Configure TF threading
tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN enabled")


def load_tokenizer():
    """Load the GPT-2 tokenizer and expose it as a ``tokenizers.Tokenizer``.

    The Hugging Face tokenizer is extended with the extra special tokens,
    saved to ``./temp_tokenizer`` and re-opened from the resulting
    ``tokenizer.json``. The result is stored in the module-level
    ``tokenizer`` global.

    Raises:
        Exception: whatever the underlying libraries raise; the error is
            printed and re-raised.
    """
    global tokenizer
    print("🚀 Loading tokenizer...")
    try:
        # Try to load from Hugging Face
        from transformers import AutoTokenizer

        hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")

        # Add special tokens specific to your models.
        # NOTE(review): these literals are blank/empty and duplicated, which
        # looks like the original markers (presumably angle-bracket tags)
        # were stripped from this file — restore the real token strings.
        special_tokens = [" ", " ", " ", " ", "", ""]
        hf_tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

        # Save temporarily to create tokenizers instance
        os.makedirs("./temp_tokenizer", exist_ok=True)
        hf_tokenizer.save_pretrained("./temp_tokenizer")
        tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")
        print(f"✅ Tokenizer loaded with vocab size: {tokenizer.get_vocab_size()}")
    except Exception as e:
        print(f"❌ Error loading tokenizer: {e}")
        raise


def load_model():
    """Load the model selected by the ``MODEL_TYPE`` environment variable.

    Downloads ``config.json`` from the repository that matches
    ``MODEL_TYPE``, builds a ``SAM1Model`` from it, tries to load
    ``model.weights.h5`` from the *same* repository and runs one warm-up
    call. Sets the ``model``, ``model_loaded`` and ``model_max_len``
    globals.

    If the weights cannot be downloaded or loaded, a warning is printed and
    the model keeps its random initialization (best-effort behavior).

    Raises:
        Exception: any failure other than weight loading is printed and
            re-raised.
    """
    global model, model_loaded, model_max_len
    print(f"🚀 Loading {MODEL_TYPE} model...")
    try:
        # Config and weights must come from the same repository; previously
        # the weights were always fetched from MODEL_REPO, which does not
        # match the architecture of the nano/mini/fast configs.
        repo_id = MODEL_TYPE_REPOS.get(MODEL_TYPE, MODEL_REPO)

        config_path = hf_hub_download(repo_id, "config.json", cache_dir=CACHE_DIR)
        with open(config_path, 'r') as f:
            config = json.load(f)

        # Build model from config
        hidden_size = config.get('hidden_size', 768)
        model_config = {
            'vocab_size': config.get('vocab_size', 50432),
            'd_model': hidden_size,
            'n_layers': config.get('num_hidden_layers', 12),
            'n_heads': config.get('num_attention_heads', 12),
            'ff_mult': config.get('intermediate_size', 3072) / hidden_size,
            'max_len': config.get('max_position_embeddings', 2048),
            'dropout': 0.1,
            'rope_theta': config.get('rope_theta', 10000)
        }

        from model_architecture import SAM1Model  # Import from your architecture file

        model = SAM1Model(config=model_config)
        model_max_len = model_config['max_len']

        # Build model with dummy input so that variables exist before
        # weights are loaded.
        dummy_input = tf.zeros((1, 16), dtype=tf.int32)
        _ = model(dummy_input, training=False, use_cache=False)
        print(f"✅ Model loaded: {config.get('num_hidden_layers', 12)} layers")

        # Try to load weights
        try:
            weights_path = hf_hub_download(repo_id, "model.weights.h5", cache_dir=CACHE_DIR)
            model.load_weights(weights_path)
            print("✅ Model weights loaded successfully!")
        except Exception as e:
            print(f"⚠️ Could not load weights, using random initialization: {e}")

        # Warm up the model
        print("🔥 Warming up model...")
        warmup_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)
        _, _ = model(warmup_input, training=False, use_cache=True)
        print("✅ Model warmed up")

        model_loaded = True
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise


def format_chat_prompt(messages: List[Dict[str, str]]) -> str:
    """Format chat messages into a single prompt string for the model.

    Args:
        messages: dicts with optional ``role`` (defaults to ``'user'``) and
            ``content`` (defaults to ``''``) keys.

    Returns:
        The concatenated prompt, ending with the assistant prefix.
    """
    # NOTE(review): the user/assistant wrappers and the assistant prefix are
    # plain spaces here, so both roles render identically. This looks like
    # role markers were stripped from the file — confirm and restore.
    parts = []
    for msg in messages:
        role = str(msg.get('role') or 'user').lower()
        content = msg.get('content', '')
        if role == 'user':
            parts.append(f" {content} ")
        elif role == 'assistant':
            parts.append(f" {content} ")
        else:
            # System or other roles
            parts.append(f"{content}\n")

    # Add assistant prefix for the response
    parts.append(" ")
    return "".join(parts)


def sample_token(logits, temperature=0.8, top_k=40, top_p=0.9,
                 repetition_penalty=1.1, previous_tokens=None):
    """Sample the next token id from a 1-D array of logits.

    Args:
        logits: unnormalized scores, one per vocabulary entry.
        temperature: softmax temperature; values ``<= 0`` select the
            highest-scoring token deterministically.
        top_k: keep only the ``top_k`` most likely tokens (``<= 0`` or
            ``>= vocab`` disables the filter).
        top_p: nucleus threshold applied *after* top-k (``>= 1.0``
            disables the filter).
        repetition_penalty: factor used to make tokens listed in
            ``previous_tokens`` less likely (``1.0`` disables it).
        previous_tokens: optional iterable of token ids that were already
            produced. The penalty is applied to these ids only; when it is
            ``None`` or empty no penalty is applied.

    Returns:
        The sampled token id as a Python ``int``.
    """
    # float64 avoids np.random.choice rejecting float32 probabilities whose
    # sum is not close enough to 1; copy so the caller's array is untouched.
    scores = np.array(logits, dtype=np.float64).reshape(-1)
    vocab_size = scores.shape[0]

    # Repetition penalty: only tokens that were actually seen are penalized.
    # (Penalizing every logit, as before, merely distorts the distribution.)
    if repetition_penalty != 1.0 and previous_tokens is not None:
        seen = np.unique(np.fromiter(previous_tokens, dtype=np.int64))
        seen = seen[(seen >= 0) & (seen < vocab_size)]
        if seen.size:
            seen_scores = scores[seen]
            scores[seen] = np.where(
                seen_scores < 0,
                seen_scores * repetition_penalty,
                seen_scores / repetition_penalty,
            )

    # A non-positive temperature cannot be divided by; treat it as greedy.
    if temperature <= 0:
        return int(np.argmax(scores))

    scores = scores / temperature

    # Convert to probabilities (subtract the max for numerical stability)
    probs = np.exp(scores - np.max(scores))
    probs = probs / np.sum(probs)

    # Top-k filtering
    if 0 < top_k < vocab_size:
        candidate_ids = np.argpartition(probs, -top_k)[-top_k:]
    else:
        candidate_ids = np.arange(vocab_size)
    candidate_probs = probs[candidate_ids]
    candidate_probs = candidate_probs / np.sum(candidate_probs)

    # Top-p (nucleus) filtering on whatever survived top-k
    if top_p < 1.0:
        order = np.argsort(candidate_probs)[::-1]
        sorted_probs = candidate_probs[order]
        cumulative = np.cumsum(sorted_probs)
        # Always keep at least one token, and never more than available.
        cutoff = min(int(np.searchsorted(cumulative, top_p)) + 1, len(order))
        candidate_ids = candidate_ids[order[:cutoff]]
        candidate_probs = sorted_probs[:cutoff] / np.sum(sorted_probs[:cutoff])

    choice = np.random.choice(len(candidate_ids), p=candidate_probs)
    return int(candidate_ids[choice])


def generate_response(prompt: str, max_tokens: int = 512, temperature: float = 0.8,
                      top_k: int = 40, top_p: float = 0.9,
                      repetition_penalty: float = 1.1) -> str:
    """Generate a text completion for ``prompt`` with the loaded model.

    Generation runs without a KV cache, so the full token sequence (prompt
    plus everything generated so far, limited to the model's maximum
    length) is fed to the model at every step.

    Args:
        prompt: text to continue.
        max_tokens: upper bound on the number of generated tokens.
        temperature, top_k, top_p, repetition_penalty: sampling settings
            forwarded to ``sample_token``.

    Returns:
        The decoded completion with surrounding whitespace stripped.

    Raises:
        Exception: if the model has not been loaded.
        ValueError: if the prompt encodes to zero tokens.
    """
    global model, tokenizer
    if not model_loaded:
        raise Exception("Model not loaded")

    # Tokenize the prompt
    prompt_ids = list(tokenizer.encode(prompt).ids)
    if not prompt_ids:
        raise ValueError("Prompt produced no tokens")

    # NOTE(review): the textual stop markers are blank/empty in this file,
    # which looks like stripped markup — confirm the intended end tokens.
    stop_markers = [" ", ""]

    # Resolve the stop ids once; token_to_id returns None for unknown tokens.
    stop_ids = {50256}
    for marker in stop_markers:
        marker_id = tokenizer.token_to_id(marker)
        if marker_id is not None:
            stop_ids.add(marker_id)

    context_ids = prompt_ids
    generated_ids = []

    # Process tokens one by one (simplified generation without KV cache)
    for _ in range(max_tokens):
        # Keep only the most recent tokens the model can attend to.
        window = context_ids[-model_max_len:]
        with tf.device('/CPU:0'):  # Use CPU for inference
            logits, _ = model(
                tf.constant([window], dtype=tf.int32),
                training=False,
                use_cache=False,
            )
        next_token_logits = logits[0, -1, :].numpy()

        # Sample next token, penalizing what was already generated
        next_token_id = sample_token(
            next_token_logits,
            temperature,
            top_k,
            top_p,
            repetition_penalty,
            previous_tokens=generated_ids,
        )

        # Stop if we hit an end token (it is not part of the answer)
        if next_token_id in stop_ids:
            break

        generated_ids.append(next_token_id)
        # Without a cache the model must see the whole sequence again.
        context_ids.append(next_token_id)

    # Decode the generated tokens
    generated_text = tokenizer.decode(generated_ids)

    # Remove any end tokens that might have been included. Empty markers are
    # skipped: str.find("") is always 0 and would erase the whole response.
    for marker in stop_markers:
        if not marker:
            continue
        idx = generated_text.find(marker)
        if idx != -1:
            generated_text = generated_text[:idx]

    return generated_text.strip()


@app.on_event("startup")
def startup_event():
    """Initialize tokenizer and model when the application starts.

    Failures are printed and leave ``model_loaded`` set to ``False`` so the
    worker keeps running and reports itself as unhealthy.
    """
    global model_loaded
    print(f"Initializing worker for model type: {MODEL_TYPE}")
    try:
        load_tokenizer()
        load_model()
        print("✅ Worker initialized successfully!")
    except Exception as e:
        print(f"❌ Worker initialization failed: {e}")
        model_loaded = False


def _count_tokens(text: str) -> int:
    """Return the number of tokenizer tokens in ``text`` (0 for empty text)."""
    if not text:
        return 0
    return len(tokenizer.encode(text).ids)


@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
    """Process a chat completion request.

    Returns an OpenAI-style response dict. Responds with HTTP 503 when the
    model is not loaded and HTTP 500 when generation fails.
    """
    global model_loaded
    if not model_loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # Format the messages into a single prompt
        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
        prompt = format_chat_prompt(messages)

        # Generate response.
        # NOTE(review): generate_response is blocking and runs on the event
        # loop, so other requests (including /health) wait until it returns.
        start_time = time.time()
        response_text = generate_response(
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            repetition_penalty=request.repetition_penalty
        )
        processing_time = time.time() - start_time

        # Usage is reported in tokens, not characters.
        prompt_tokens = _count_tokens(prompt)
        completion_tokens = _count_tokens(response_text)

        # Create response in OpenAI-compatible format. A random suffix keeps
        # ids unique even for requests handled within the same second.
        response = ChatResponse(
            id=f"chat-{uuid.uuid4().hex}",
            model=request.model,
            choices=[
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": response_text},
                    "finish_reason": "stop"
                }
            ],
            usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
        )

        print(f"Generated response in {processing_time:.2f}s for model {request.model}")
        return response.dict()
    except Exception as e:
        print(f"Error processing request: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")


@app.get("/health")
async def health_check():
    """Health check endpoint reporting whether the model is loaded."""
    return {
        "status": "healthy" if model_loaded else "unhealthy",
        "model_type": MODEL_TYPE,
        "model_loaded": model_loaded,
        "timestamp": int(time.time())
    }


@app.get("/model-info")
async def model_info():
    """Get information about the loaded model (HTTP 404 if not loaded)."""
    if not model_loaded:
        raise HTTPException(status_code=404, detail="Model not loaded")

    return {
        "model_type": MODEL_TYPE,
        "vocab_size": tokenizer.get_vocab_size() if tokenizer else 0,
        "parameters": model.count_params() if model else 0,
        # Taken from the downloaded config (2048 when the config omits it)
        "max_context_length": model_max_len
    }


if __name__ == "__main__":
    port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)