Bc-AI commited on
Commit
666ed75
·
verified ·
1 Parent(s): 3f1837e

Upload folder using huggingface_hub

Browse files
Dockerfile ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

WORKDIR /app

# Install system build tools needed to compile native wheels
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc \
    g++ \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first to leverage Docker layer cache
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code.
# NOTE: the original used "COPY ../shared ./shared", which is invalid --
# COPY cannot reference paths outside the build context. Build with a
# context that contains shared/ (or vendor it in) so this COPY works.
COPY worker_app.py .
COPY model_architecture.py .
COPY shared ./shared

# Expose port for the API
EXPOSE 8000

# Start the application
CMD ["python", "worker_app.py"]
README.md CHANGED
@@ -1,10 +1,12 @@
1
- ---
2
- title: Worker Large
3
- emoji: 🌍
4
- colorFrom: red
5
- colorTo: green
6
- sdk: docker
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
1
+ # SACCP Worker_Large Node
2
+ This is a worker_large node in the SACCP (Scalable Accelerated Compute Protocol) distributed computing network.
3
+
4
+ ## Node Type: WORKER_LARGE
5
+ - Processes tasks according to SACCP protocol
6
+ - Contributes computational resources to the network
7
+ - Earns cloud credits for resource contribution
8
+
9
+ ## Architecture
10
+ - Built with FastAPI and TensorFlow/Keras
11
+ - Implements fault-tolerant operations
12
+ - Integrated with SACCP credit system
model_architecture.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import keras
3
+ import numpy as np
4
+
5
@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
    """Rotary position embedding (RoPE) applied to query/key tensors."""

    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        # cos/sin lookup tables are computed lazily on first call.
        self.built_cache = False
        self.cos_cached = None
        self.sin_cached = None

    def build(self, input_shape):
        super().build(input_shape)

    def _build_cache(self):
        """Precompute cos/sin tables for every position up to max_len (once)."""
        if self.built_cache:
            return
        exponents = tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim
        inv_freq = 1.0 / (self.theta ** exponents)
        positions = tf.range(self.max_len, dtype=tf.float32)
        angles = tf.einsum("i,j->ij", positions, inv_freq)
        table = tf.concat([angles, angles], axis=-1)
        self.cos_cached = tf.constant(np.cos(table.numpy()), dtype=tf.float32)
        self.sin_cached = tf.constant(np.sin(table.numpy()), dtype=tf.float32)
        self.built_cache = True

    def rotate_half(self, x):
        """Split the last axis in half and swap the halves as (-x2, x1)."""
        first, second = tf.split(x, 2, axis=-1)
        return tf.concat([-second, first], axis=-1)

    def call(self, q, k, offset=0):
        """Apply rotary embeddings with position offset."""
        self._build_cache()
        seq_len = tf.shape(q)[2]
        dtype = q.dtype

        cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]

        rotated_q = (q * cos) + (self.rotate_half(q) * sin)
        rotated_k = (k * cos) + (self.rotate_half(k) * sin)
        return rotated_q, rotated_k

    def get_config(self):
        config = super().get_config()
        config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
        return config
50
+
51
+
52
@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
    """Root-mean-square layer normalization (no mean-centering, learned gain)."""

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        self.scale = None

    def build(self, input_shape):
        # One learned gain per feature on the last axis.
        self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
        super().build(input_shape)

    def call(self, x):
        mean_square = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        normalized = x * tf.math.rsqrt(mean_square + self.epsilon)
        return normalized * self.scale

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config
71
+
72
+
73
@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    """Pre-norm decoder block: causal multi-head self-attention with RoPE,
    followed by a SwiGLU feed-forward network, each wrapped in a residual
    connection with dropout.
    """

    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model          # model (embedding) width
        self.n_heads = n_heads          # number of attention heads
        self.ff_dim = ff_dim            # hidden width of the feed-forward net
        self.dropout_rate = dropout
        self.max_len = max_len          # max sequence length for the RoPE tables
        self.rope_theta = rope_theta    # RoPE base frequency
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx      # position of this block in the stack

    def build(self, input_shape):
        # Pre-normalization (RMSNorm) before attention and before the FFN.
        self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
        self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
        # Attention projections, all bias-free.
        self.q_proj = keras.layers.Dense(self.d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(self.d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(self.d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
        # Gated (SwiGLU-style) feed-forward projections.
        self.gate_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
        self.dropout = keras.layers.Dropout(self.dropout_rate)
        super().build(input_shape)

    def call(self, x, training=None, past_kv=None, use_cache=False):
        """Simplified call without KV cache for this example"""
        B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
        dtype = x.dtype

        res = x
        y = self.pre_attn_norm(x)

        # Multi-head attention: project, then reshape to [B, heads, T, head_dim].
        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])

        # Apply RoPE (offset fixed at 0 -- no KV cache in this version).
        q, k = self.rope(q, k, offset=0)

        # Scaled dot-product attention scores.
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))

        # Causal mask: band_part(ones, -1, 0) keeps the LOWER triangle, so each
        # position attends only to itself and earlier positions; masked slots
        # get -1e9 before the softmax.
        mask = tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0)
        mask = tf.where(mask == 0, tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
        scores = scores + mask[None, None, :, :]

        attn = tf.nn.softmax(scores, axis=-1)
        attn_out = tf.matmul(attn, v)
        attn_out = tf.transpose(attn_out, [0, 2, 1, 3])
        attn_out = tf.reshape(attn_out, [B, T, self.d_model])

        x = res + self.dropout(self.out_proj(attn_out), training=training)

        # Gated FFN: down(silu(gate(y)) * up(y)), with residual.
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
        output = res + self.dropout(ffn, training=training)

        return output, None  # Return None for past_kv in this simplified version

    def get_config(self):
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate,
            "max_len": self.max_len,
            "rope_theta": self.rope_theta,
            "layer_idx": self.layer_idx
        })
        return config
151
+
152
+
153
@keras.saving.register_keras_serializable()
class SAM1Model(keras.Model):
    """Decoder-only transformer language model: token embedding -> stack of
    TransformerBlocks -> final RMSNorm -> linear head over the vocabulary.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # Accept the configuration in any of three shapes, for serialization
        # compatibility: {'config': {...}} (what get_config emits), the flat
        # kwargs themselves (detected via 'vocab_size'), or {'cfg': {...}}.
        if 'config' in kwargs and isinstance(kwargs['config'], dict):
            self.cfg = kwargs['config']
        elif 'vocab_size' in kwargs:
            self.cfg = kwargs
        else:
            self.cfg = kwargs.get('cfg', kwargs)

        self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
        # FFN width is expressed as a multiple of d_model.
        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
        block_args = {
            'd_model': self.cfg['d_model'],
            'n_heads': self.cfg['n_heads'],
            'ff_dim': ff_dim,
            'dropout': self.cfg['dropout'],
            'max_len': self.cfg['max_len'],
            'rope_theta': self.cfg['rope_theta']
        }
        self.blocks = [
            TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
            for i in range(self.cfg['n_layers'])
        ]
        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")

    def call(self, input_ids, training=None, past_kv=None, use_cache=False):
        """
        Simplified call without full KV cache implementation
        """
        x = self.embed(input_ids)

        for block in self.blocks:
            x, _ = block(x, training=training, past_kv=None, use_cache=False)

        logits = self.lm_head(self.norm(x))
        return logits, None  # Return None for past_kv in this simplified version

    def get_config(self):
        # Persist the raw config dict so from_config() can rebuild the model.
        base_config = super().get_config()
        base_config['config'] = self.cfg
        return base_config
197
+
198
+
199
def count_parameters(model):
    """Return the total number of scalar parameters in *model*.

    Sums the element counts of all weights from their shapes, without
    materializing each weight as a NumPy array (the original copied every
    tensor to host memory just to read its size).
    """
    return sum(int(np.prod(w.shape)) for w in model.weights)
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Requirements for Worker Nodes
2
+ keras==2.15.0
3
+ tensorflow==2.15.0
4
+ fastapi==0.104.1
5
+ uvicorn==0.24.0
6
+ requests==2.31.0
7
+ huggingface_hub==0.20.1
8
+ tokenizers==0.15.0
9
+ transformers==4.35.2
10
+ numpy==1.24.3
11
+ pytz==2023.3.post1
shared/chat_history.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ from datetime import datetime
5
+ from typing import List, Dict, Any
6
+ from .models import ChatMessage
7
+
8
+
9
def save_chat_history(messages: List[ChatMessage], model_name: str, response: str, filename: str = "chat.md"):
    """Append one chat session (all messages plus the final response) to a
    markdown log file, stamped with the time and model name.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Session header.
    history_content = f"""
## Chat Session: {timestamp}
**Model Used:** {model_name}

---
"""

    # One markdown paragraph per message, labelled by speaker.
    for message in messages:
        if message.role.lower() == "user":
            prefix = "**User:**"
        else:
            prefix = "**Assistant:**"
        history_content += f"\n{prefix} {message.content}\n\n"

    # Final assistant response, followed by a session separator.
    history_content += f"\n**Assistant Response:** {response}\n\n---\n\n"

    # Append so earlier sessions are preserved.
    with open(filename, "a", encoding="utf-8") as file:
        file.write(history_content)
34
+
35
+
36
def save_detailed_chat_log(request_data: Dict[str, Any], response_data: str, model_name: str, processing_time: float, filename: str = "chat.md"):
    """Append a detailed request/response log entry (with sampling metadata)
    to a markdown file.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Header with request metadata; sampling settings fall back to defaults.
    log_content = f"""
## Chat Request Log: {timestamp}
- **Model:** {model_name}
- **Processing Time:** {processing_time:.2f}s
- **Max Tokens:** {request_data.get('max_tokens', 512)}
- **Temperature:** {request_data.get('temperature', 0.8)}

### Input Messages:
"""

    # One bullet per input message, labelled by speaker.
    for entry in request_data.get('messages', []):
        speaker = "**User**" if entry.get('role', 'unknown').lower() == 'user' else "**Assistant**"
        log_content += f"- {speaker}: {entry.get('content', '')}\n"

    log_content += f"\n### Model Response:\n{response_data}\n\n---\n\n"

    # Append so earlier entries are preserved.
    with open(filename, "a", encoding="utf-8") as file:
        file.write(log_content)
65
+
66
+
67
def initialize_chat_file(filename: str = "chat.md"):
    """Create the chat history file with a header, unless it already exists."""
    if os.path.exists(filename):
        return  # never overwrite an existing history

    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    header = f"""# Chat History
Last updated: {stamp}

This file contains the history of all chat conversations processed by the multi-node API system.

---
"""
    with open(filename, "w", encoding="utf-8") as file:
        file.write(header)
shared/models.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import List, Optional, Dict, Any
3
+
4
+
5
class ChatMessage(BaseModel):
    """One turn of a conversation."""
    role: str  # "user" or "assistant"
    content: str  # the message text
8
+
9
+
10
class ChatRequest(BaseModel):
    """Incoming chat completion request with sampling parameters."""
    messages: List[ChatMessage]  # full conversation history
    model: str = "sam-x-nano"  # which Sam-X model variant to use
    max_tokens: Optional[int] = 512  # generation length cap
    temperature: Optional[float] = 0.8  # softmax temperature
    top_k: Optional[int] = 40  # top-k sampling cutoff (0 disables)
    top_p: Optional[float] = 0.9  # nucleus sampling threshold
    repetition_penalty: Optional[float] = 1.1  # >1 discourages repetition
    stream: Optional[bool] = False  # streaming flag
19
+
20
+
21
class ChatResponse(BaseModel):
    """OpenAI-style chat completion response envelope."""
    id: str  # unique response id
    object: str = "chat.completion"
    created: int  # unix timestamp; NOTE(review): no default -- constructors must supply it
    model: str  # model name echoed from the request
    choices: List[Dict[str, Any]]  # one dict per completion choice
    usage: Dict[str, int]  # prompt/completion/total accounting
28
+
29
+
30
class WorkerStatus(BaseModel):
    """Heartbeat/status report for a worker node."""
    model_name: str  # model served by this worker
    is_active: bool  # whether the worker is accepting requests
    load: float  # current load factor
    last_heartbeat: int  # unix timestamp of the last heartbeat
space-config.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # SACCP Node Space Configuration
2
+ runtime:
3
+ cpu: "medium"
4
+ memory: "16x"
5
+ accelerator: "cpu" # Will be configured based on node type
6
+ env:
7
+ NODE_TYPE: "large"
8
+ MODEL_TYPE: "sam-x-large"
worker_app.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

# Performance / device configuration MUST happen before TensorFlow is imported:
# TF reads these environment variables once, at import time. In the original
# layout they were set AFTER `import tensorflow`, where the thread counts,
# CUDA_VISIBLE_DEVICES and oneDNN settings have no effect.
NUM_CORES = os.cpu_count() or 4
os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES)
os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # Force CPU only
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'  # Intel oneDNN optimization
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Reduce TF logging

import time
import json
import asyncio
from datetime import datetime
from typing import Dict, List, Optional
from fastapi import FastAPI, HTTPException
import uvicorn
from pydantic import BaseModel
from shared.models import ChatRequest, ChatResponse, ChatMessage
import tensorflow as tf
import keras
import numpy as np
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import requests
from transformers import GPT2Tokenizer

app = FastAPI(
    title="Worker Node for Sam-X Models",
    description="Processing node for Sam-X model inference",
    version="1.0.0"
)

# Global model state, populated by load_tokenizer()/load_model() at startup.
tokenizer = None
model = None
model_loaded = False

# Configuration (overridable via environment)
MODEL_REPO = os.getenv("MODEL_REPO", "Smilyai-labs/Sam-large-2")
MODEL_TYPE = os.getenv("MODEL_TYPE", "sam-x-nano")  # Determines which model to load
CACHE_DIR = "./model_cache"

# Explicit threading configuration (this API is effective post-import).
tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)

print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN enabled")
48
+
49
+
50
def load_tokenizer():
    """Load the tokenizer from Hugging Face or local files"""
    global tokenizer

    print("🚀 Loading tokenizer...")

    try:
        # Load the base GPT-2 tokenizer from the Hugging Face Hub.
        from transformers import AutoTokenizer
        hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")

        # Add special tokens specific to your models.
        # NOTE(review): the first four entries render as bare newlines in this
        # copy of the file -- they look like chat-role marker tokens (e.g.
        # "<|user|>"/"<|assistant|>") whose angle-bracket text was stripped in
        # transit. Confirm the intended token strings before relying on this.
        special_tokens = ["\n", "\n", "\n", "\n", "<CONTINUE>", "<im end for model tun>"]
        hf_tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

        # Save temporarily so the fast `tokenizers.Tokenizer` can be rebuilt
        # from the serialized tokenizer.json.
        os.makedirs("./temp_tokenizer", exist_ok=True)
        hf_tokenizer.save_pretrained("./temp_tokenizer")
        tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")

        print(f"✅ Tokenizer loaded with vocab size: {tokenizer.get_vocab_size()}")

    except Exception as e:
        print(f"❌ Error loading tokenizer: {e}")
        raise
79
+
80
+
81
def load_model():
    """Load the model selected by the MODEL_TYPE environment variable.

    Downloads the config (and weights, when available) from the Hugging Face
    Hub, builds a SAM1Model from it, runs a warm-up pass, and sets the
    module-level ``model`` / ``model_loaded`` globals.

    Raises:
        Exception: re-raises any failure during config download or model build
        (weight-loading failures are tolerated and fall back to random init).
    """
    global model, model_loaded

    print(f"🚀 Loading {MODEL_TYPE} model...")

    # Map each known model type to its Hub repository; anything else falls
    # back to the configured default repo. (The original duplicated the
    # download-and-parse sequence once per branch.)
    repo_by_type = {
        "sam-x-nano": "Smilyai-labs/Sam-nano",
        "sam-x-mini": "Smilyai-labs/Sam-mini",
        "sam-x-fast": "Smilyai-labs/Sam-fast",
    }
    repo = repo_by_type.get(MODEL_TYPE, MODEL_REPO)

    try:
        config_path = hf_hub_download(repo, "config.json", cache_dir=CACHE_DIR)
        with open(config_path, 'r') as f:
            config = json.load(f)

        # Translate HF-style config keys into SAM1Model's constructor config.
        model_config = {
            'vocab_size': config.get('vocab_size', 50432),
            'd_model': config.get('hidden_size', 768),
            'n_layers': config.get('num_hidden_layers', 12),
            'n_heads': config.get('num_attention_heads', 12),
            'ff_mult': config.get('intermediate_size', 3072) / config.get('hidden_size', 768),
            'max_len': config.get('max_position_embeddings', 2048),
            'dropout': 0.1,
            'rope_theta': config.get('rope_theta', 10000)
        }

        from model_architecture import SAM1Model  # Import from your architecture file
        model = SAM1Model(config=model_config)

        # Build the model's variables with a dummy forward pass.
        dummy_input = tf.zeros((1, 16), dtype=tf.int32)
        _ = model(dummy_input, training=False, use_cache=False)

        print(f"✅ Model loaded: {config.get('num_hidden_layers', 12)} layers")

        # Load weights from the SAME repo the config came from. (The original
        # always fetched weights from MODEL_REPO, which could silently pair a
        # type-specific config with mismatched weights.)
        try:
            weights_path = hf_hub_download(repo, "model.weights.h5", cache_dir=CACHE_DIR)
            model.load_weights(weights_path)
            print("✅ Model weights loaded successfully!")
        except Exception as e:
            print(f"⚠️ Could not load weights, using random initialization: {e}")

        # Warm up the model so the first real request is not slow.
        print("🔥 Warming up model...")
        warmup_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)
        _, _ = model(warmup_input, training=False, use_cache=True)
        print("✅ Model warmed up")

        model_loaded = True

    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise
150
+
151
+
152
def format_chat_prompt(messages: List[Dict[str, str]]) -> str:
    """Format chat messages into a prompt for the model.

    NOTE(review): the role wrappers below render as bare newlines in this
    copy of the file; they were almost certainly special role-marker tokens
    (e.g. "<|user|>" / "<|assistant|>") whose angle-bracket text was
    stripped in transit. Confirm the intended markers against the
    tokenizer's special-token list before relying on this formatting.
    """
    prompt = ""

    for msg in messages:
        role = msg.get('role', 'user')
        content = msg.get('content', '')

        if role.lower() == 'user':
            # User turn wrapper (see NOTE above about the garbled marker).
            prompt += f"\n{content}\n"
        elif role.lower() == 'assistant':
            # Assistant turn wrapper (see NOTE above).
            prompt += f"\n{content}\n"
        else:
            # System or other roles pass through unwrapped.
            prompt += f"{content}\n"

    # Open the assistant turn so the model continues from here.
    prompt += "\n"

    return prompt
177
+
178
+
179
def sample_token(logits, temperature=0.8, top_k=40, top_p=0.9, repetition_penalty=1.1):
    """Sample the index of the next token from a vector of raw logits.

    Applies temperature scaling and a global repetition penalty, then draws a
    token via top-k sampling when 0 < top_k < vocab size, otherwise via
    nucleus (top-p) sampling when top_p < 1, otherwise from the full
    distribution.
    """
    # Temperature scaling.
    scaled = logits / temperature

    # Repetition penalty (note: applied to the whole vocabulary, not only to
    # previously generated tokens).
    if repetition_penalty != 1.0:
        scaled = np.where(scaled < 0, scaled * repetition_penalty, scaled / repetition_penalty)

    # Softmax with max-subtraction for numerical stability.
    exp_shifted = np.exp(scaled - np.max(scaled))
    probs = exp_shifted / np.sum(exp_shifted)

    vocab = len(probs)

    # Top-k sampling takes precedence whenever it actually restricts the vocab.
    if 0 < top_k < vocab:
        candidate_idx = np.argpartition(probs, -top_k)[-top_k:]
        candidate_probs = probs[candidate_idx]
        candidate_probs = candidate_probs / np.sum(candidate_probs)  # renormalize
        choice = np.random.choice(len(candidate_idx), p=candidate_probs)
        return candidate_idx[choice]

    # Nucleus (top-p): smallest prefix of the sorted distribution whose
    # cumulative probability reaches top_p.
    if top_p < 1.0:
        order = np.argsort(probs)[::-1]
        ordered_probs = probs[order]
        cumulative = np.cumsum(ordered_probs)
        cut = np.searchsorted(cumulative, top_p)
        cut = min(cut + 1, len(order))

        nucleus = order[:cut]
        nucleus_probs = probs[nucleus]
        nucleus_probs = nucleus_probs / np.sum(nucleus_probs)  # renormalize
        choice = np.random.choice(len(nucleus), p=nucleus_probs)
        return nucleus[choice]

    # Unrestricted sampling from the full distribution.
    return np.random.choice(vocab, p=probs)
216
+
217
+
218
def generate_response(prompt: str, max_tokens: int = 512, temperature: float = 0.8,
                      top_k: int = 40, top_p: float = 0.9, repetition_penalty: float = 1.1) -> str:
    """Generate a completion for *prompt* with the loaded model.

    Token-by-token loop without a KV cache: the full sequence seen so far is
    re-fed to the model on every step.

    NOTE(review): several stop-token strings below render as bare newlines in
    this copy of the file; they look like special end-of-turn tokens whose
    angle-bracket text was stripped. Confirm them against the tokenizer's
    special-token list.

    Raises:
        Exception: if the model has not been loaded yet.
    """
    global model, tokenizer

    if not model_loaded:
        raise Exception("Model not loaded")

    # Tokenize the prompt.
    prompt_ids = tokenizer.encode(prompt).ids
    current_ids = tf.constant([prompt_ids], dtype=tf.int32)

    generated_ids = []

    for _ in range(max_tokens):
        with tf.device('/CPU:0'):  # Use CPU for inference
            logits, _ = model(current_ids, training=False, use_cache=False)
            next_token_logits = logits[0, -1, :].numpy()

        # Sample the next token.
        next_token_id = sample_token(next_token_logits, temperature, top_k, top_p, repetition_penalty)
        generated_ids.append(next_token_id)

        # BUG FIX: without a KV cache the model must see the whole sequence,
        # so APPEND the new token. The original replaced the input with only
        # the single last token (`current_ids = [[next_token_id]]`), which
        # discarded all context after the first generated token.
        current_ids = tf.concat([current_ids, tf.constant([[next_token_id]], dtype=tf.int32)], axis=1)

        # Stop on an end-of-sequence / end-of-turn token.
        if next_token_id in [50256, tokenizer.token_to_id("\n"), tokenizer.token_to_id("<im end for model tun>")]:
            break

    # Decode the generated tokens.
    generated_text = tokenizer.decode(generated_ids)

    # Strip any stop-token text that slipped into the decoded output.
    stop_tokens = ["\n", "<im end for model tun>"]
    for token in stop_tokens:
        idx = generated_text.find(token)
        if idx != -1:
            generated_text = generated_text[:idx]

    return generated_text.strip()
265
+
266
+
267
@app.on_event("startup")
def startup_event():
    """Initialize model and tokenizer on startup"""
    global model_loaded

    print(f"Initializing worker for model type: {MODEL_TYPE}")

    try:
        load_tokenizer()
        load_model()
    except Exception as e:
        # Keep the server up but report unhealthy via /health.
        print(f"❌ Worker initialization failed: {e}")
        model_loaded = False
    else:
        print("✅ Worker initialized successfully!")
281
+
282
+
283
@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
    """Process a chat completion request and return an OpenAI-style response.

    Raises:
        HTTPException: 503 when the model is not loaded, 500 on any
        generation failure.
    """
    global model_loaded

    if not model_loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # Flatten the pydantic messages into plain dicts for prompt formatting.
        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
        prompt = format_chat_prompt(messages)

        # Generate response (blocking; runs on the event-loop thread).
        start_time = time.time()
        response_text = generate_response(
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            repetition_penalty=request.repetition_penalty
        )
        processing_time = time.time() - start_time

        # Create response in OpenAI-compatible format.
        # BUG FIX: ChatResponse declares `created: int` with no default, so
        # omitting it (as the original did) raised a pydantic ValidationError
        # on every request.
        # NOTE(review): the "usage" numbers are *character* counts, not token
        # counts, despite the OpenAI field names -- consider counting
        # tokenizer ids instead.
        response = ChatResponse(
            id=f"chat-{int(time.time())}",
            created=int(time.time()),
            model=request.model,
            choices=[
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": response_text},
                    "finish_reason": "stop"
                }
            ],
            usage={
                "prompt_tokens": len(prompt),
                "completion_tokens": len(response_text),
                "total_tokens": len(prompt) + len(response_text)
            }
        )

        print(f"Generated response in {processing_time:.2f}s for model {request.model}")

        # .dict() is the pydantic v1 API (deprecated in v2 in favour of .model_dump()).
        return response.dict()

    except Exception as e:
        print(f"Error processing request: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
333
+
334
+
335
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    status = "healthy" if model_loaded else "unhealthy"
    return {
        "status": status,
        "model_type": MODEL_TYPE,
        "model_loaded": model_loaded,
        "timestamp": int(time.time()),
    }
344
+
345
+
346
@app.get("/model-info")
async def model_info():
    """Get information about the loaded model"""
    if not model_loaded:
        raise HTTPException(status_code=404, detail="Model not loaded")

    return {
        "model_type": MODEL_TYPE,
        "vocab_size": tokenizer.get_vocab_size() if tokenizer else 0,
        "parameters": model.count_params() if model else 0,
        "max_context_length": 2048  # Default, would be from config
    }
358
+
359
+
360
# Script entry point: serve the API on all interfaces, port from $PORT (default 8000).
if __name__ == "__main__":
    port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)