# Gateway / main.py
# SHERWYNLUCIAN's picture
# Upload 5 files
# c167cfd verified
# Raw
# History Blame Contribute Delete
# 23 kB
import sys
import os
import json
import time
import requests
import asyncio
import threading
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import uvicorn
from pydantic import BaseModel
from typing import List, Optional
import socketio
import uuid
import psutil
# Identity of this gateway node: WORKER_NODE_ID when set, otherwise a random
# "Gateway-<8 hex chars>" name generated once per process.
node_id = os.environ.get("WORKER_NODE_ID", f"Gateway-{uuid.uuid4().hex[:8]}")
# Socket.IO client towards the Data Center. The client retries dropped
# connections itself: first retry after 3 s, delay capped at 30 s.
sio = socketio.Client(
    reconnection=True,
    reconnection_delay=3,
    reconnection_delay_max=30,
)
# ONNX Runtime sessions keyed by shard filename (e.g. "shard_3.onnx").
onnx_sessions = {}
@sio.event
def connect():
    """Register this gateway with the Data Center each time the socket (re)connects.

    The client is created with automatic reconnection, so this handler runs
    again after every reconnect. Shards that already have a live ONNX session
    in ``onnx_sessions`` are re-advertised; the shard index is the run of
    digits before ``.onnx`` in the filename. Without this, a reconnect would
    report the node as shard-less. On the very first connect no session exists
    yet, so the payload is the same empty-shard registration as before.
    """
    import re

    print("[Socket.IO] Đã kết nối tới Data Center")
    shards = []
    # Snapshot the keys: a shard-allocation thread may be adding sessions.
    for filename in list(onnx_sessions):
        match = re.search(r'(\d+)\.onnx', filename)
        if match:
            shard_id = int(match.group(1))
            if shard_id not in shards:
                shards.append(shard_id)
    sio.emit('worker_register', {
        'nodeId': node_id,
        'userId': os.environ.get("WORKER_USER_ID", None),
        'region': os.environ.get("WORKER_NODE_REGION", "Unknown"),
        'capabilities': ['routing', 'inference', 'process'] + [f"inference:shard-{s}" for s in shards],
        'auth_token': os.environ.get('WORKER_AUTH_SECRET', ''),
        'shards': shards,
        'hw_score': 3000,
        'hw_tier': 'Platinum',
        'role': 'Gateway',
    })
@sio.event
def disconnect():
    """Log that the Data Center link dropped (the client is configured to reconnect on its own)."""
    notice = "[Socket.IO] Mất kết nối tới Data Center"
    print(notice)
@sio.on('new_task')
def on_new_task(data):
    """Handle a helper task pushed by the Data Center admin.

    Task types:
      * ``allocate_shards``: start ``pull_shards_and_start_engine`` on a daemon
        thread. That thread reports the task result itself.
      * ``cleanup_shards``: delete and recreate the local ``fl_weights`` folder,
        then emit a ``cleaned`` or ``error`` task result.
      * anything else: simulated assistance. Wait 1.5 s and report it as processed.
    """
    print(f"[Socket.IO][Auto-Scaling] Nhận task hỗ trợ từ Admin: {data}")
    task_id = data.get("task_id")
    task_type = data.get("type")

    if task_type == "allocate_shards":
        print("[Gateway] 📦 Nhận lệnh phân bổ Shards. Đang tiến hành tải...")
        shard_args = (
            data.get("files", []),
            data.get("repo_id", ""),
            data.get("hf_token", ""),
            task_id,
            data.get("seeders", {}),
        )
        worker = threading.Thread(target=pull_shards_and_start_engine, args=shard_args, daemon=True)
        worker.start()
        return

    if task_type == "cleanup_shards":
        print("[Gateway] 🧹 Nhận lệnh dọn rác Shards...")
        try:
            import shutil
            weights_dir = os.path.join(os.path.dirname(__file__), "fl_weights")
            if os.path.exists(weights_dir):
                shutil.rmtree(weights_dir)
            os.makedirs(weights_dir, exist_ok=True)
            sio.emit('task_result', {
                'task_id': task_id,
                'result': {'status': 'cleaned'},
                'status': 'completed',
                'worker_id': node_id,
            })
        except Exception as err:
            sio.emit('task_result', {
                'task_id': task_id,
                'result': {'status': 'error', 'info': str(err)},
                'status': 'error',
                'worker_id': node_id,
            })
        return

    # Simulated auxiliary processing that takes load off the Supernode.
    time.sleep(1.5)
    assist_result = {'status': 'processed', 'info': f"Gateway assisted with task {task_type}"}
    sio.emit('task_result', {
        'task_id': task_id,
        'result': assist_result,
        'status': 'completed',
        'processing_time_ms': 1500,
        'proof_hash': 'gateway-assist-proof',
        'worker_id': node_id,
    })
def _extract_shard_ids(files):
    """Return the distinct shard indices named in ``files``, in first-seen order.

    The index is the run of digits directly before ``.onnx``
    (``shard_3.onnx`` -> 3). Companion files such as ``shard_3.onnx_data``
    match the same pattern, so duplicates are dropped rather than advertising
    a shard twice.
    """
    import re

    shard_ids = []
    for name in files:
        match = re.search(r'(\d+)\.onnx', name)
        if match:
            shard_id = int(match.group(1))
            if shard_id not in shard_ids:
                shard_ids.append(shard_id)
    return shard_ids


def _download_from_seeder(seeder_url, filename, save_dir):
    """Stream ``filename`` from a peer seeder into ``save_dir``.

    Returns True on success and False otherwise; the caller then falls back
    to the Hugging Face Hub. If writing had started, the partial file is
    removed so the fallback never sees a truncated shard.
    """
    import requests

    file_path = os.path.join(save_dir, filename)
    url = f"{seeder_url}/api/v1/worker/download-shard/{filename}"
    started_writing = False
    try:
        # The ``with`` block releases the streamed connection even when the
        # seeder answers with a non-200 status and the body is never read.
        with requests.get(url, stream=True, timeout=120) as res:
            if res.status_code != 200:
                print(f"[Gateway] ⚠️ Seeder {seeder_url} returned HTTP {res.status_code} for {filename}")
                return False
            # ``filename`` may contain sub-folders.
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            started_writing = True
            with open(file_path, "wb") as out_file:
                for chunk in res.iter_content(chunk_size=65536):
                    if chunk:
                        out_file.write(chunk)
        return True
    except Exception as e:
        print(f"[Gateway] ⚠️ Seeder download of {filename} from {seeder_url} failed: {e}")
        if started_writing:
            try:
                os.remove(file_path)
            except OSError:
                pass  # best effort: the Hub download overwrites it anyway
        return False


def _download_from_hub(repo_id, filename, save_dir, token, attempts=3, retry_delay=5):
    """Download ``filename`` from Hugging Face repo ``repo_id`` into ``save_dir``.

    Tries up to ``attempts`` times and sleeps ``retry_delay`` seconds between
    attempts (not after the last one). Returns True on success.
    """
    from huggingface_hub import hf_hub_download

    for attempt in range(1, attempts + 1):
        try:
            hf_hub_download(repo_id=repo_id, filename=filename, local_dir=save_dir, token=token)
            return True
        except Exception as e:
            print(f"[Gateway] ⚠️ HF download of {filename} failed (attempt {attempt}/{attempts}): {e}")
            if attempt < attempts:
                time.sleep(retry_delay)
    return False


def _load_onnx_sessions(files, save_dir):
    """Open an ONNX Runtime session for every downloaded ``.onnx`` file in ``files``.

    Sessions are stored in the module-level ``onnx_sessions`` dict. That dict
    is never cleared, so shards accumulate across allocation rounds. A file
    that fails to load is logged and skipped.

    Raises:
        ImportError: if onnxruntime is not installed.
    """
    import onnxruntime as ort

    sess_options = ort.SessionOptions()
    # NOTE(review): all graph optimizations are disabled and the CPU memory
    # arena / memory-pattern planner are off (presumably to limit memory use);
    # confirm before changing.
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
    sess_options.enable_cpu_mem_arena = False
    sess_options.enable_mem_pattern = False
    providers = (
        ['CUDAExecutionProvider', 'CPUExecutionProvider']
        if ort.get_device() == 'GPU'
        else ['CPUExecutionProvider']
    )
    for f in files:
        file_path = os.path.join(save_dir, f)
        if not f.endswith('.onnx') or not os.path.exists(file_path):
            continue
        try:
            onnx_sessions[f] = ort.InferenceSession(file_path, sess_options=sess_options, providers=providers)
        except Exception as se:
            print(f"[Gateway] ⚠️ Bỏ qua {f}: {str(se)[:100]}")
    print(f"[Gateway] ✅ AI Engine (ONNX) Sẵn sàng: {len(onnx_sessions)}/{len(files)} shards")


def pull_shards_and_start_engine(files, repo_id, token, task_id, seeders=None):
    """Download the assigned shards, load them into ONNX Runtime and report back.

    Runs on a background thread started for an ``allocate_shards`` task.

    Download and load:
      * When ``seeders`` (a mapping whose values are base URLs) is non-empty,
        each file is first fetched from a randomly chosen seeder.
      * Otherwise, or if that fails, the file comes from Hugging Face repo
        ``repo_id``.
      * All downloaded ``.onnx`` files are then loaded.

    Reporting:
      * The gateway re-registers with the shard indices it now holds.
      * It emits a ``task_result`` for ``task_id``.
      * Any failure is reported as an ``error`` task result instead of raising.
    """
    import random

    save_dir = os.path.join(os.path.dirname(__file__), "fl_weights")
    os.makedirs(save_dir, exist_ok=True)
    try:
        files = list(files or [])
        print(f"[Gateway] Đang tải {len(files)} shards từ kho {repo_id}...")
        for f in files:
            success = False
            if seeders:
                seeder_url = random.choice(list(seeders.values()))
                success = _download_from_seeder(seeder_url, f, save_dir)
            if not success:
                success = _download_from_hub(repo_id, f, save_dir, token)
            if not success:
                raise RuntimeError(f"Thất bại tải file {f}")

        try:
            _load_onnx_sessions(files, save_dir)
        except ImportError:
            # onnxruntime is not installed. Nothing is loaded, so only pause
            # briefly; the shards are still advertised below.
            time.sleep(2)

        loaded_shards = _extract_shard_ids(files)
        sio.emit('worker_register', {
            'nodeId': node_id,
            'userId': os.environ.get("WORKER_USER_ID", None),
            'region': os.environ.get("WORKER_NODE_REGION", "Unknown"),
            'capabilities': ['routing', 'inference', 'process'] + [f"inference:shard-{s}" for s in loaded_shards],
            'auth_token': os.environ.get('WORKER_AUTH_SECRET', ''),
            'shards': loaded_shards,
            'hw_score': 3000,
            'hw_tier': 'Platinum',
            'role': 'Gateway',
        })
        sio.emit('task_result', {
            'task_id': task_id,
            'result': {'status': 'allocated', 'info': f'Đã tải {len(files)} shards'},
            'status': 'completed',
            'processing_time_ms': 5000,  # NOTE(review): fixed value, not measured
            'worker_id': node_id,
        })
    except Exception as e:
        sio.emit('task_result', {
            'task_id': task_id,
            'result': {'status': 'error', 'info': str(e)},
            'status': 'error',
            'worker_id': node_id,
        })
# Upper bound on how many requests keep a KV cache in ``gateway_state``.
# Beyond it, the least recently used request's cache is evicted.
_MAX_KV_CACHE_REQUESTS = 128
_KV_CACHE_LOCK = threading.Lock()


def _touch_kv_cache(kv_caches, request_id, max_requests=_MAX_KV_CACHE_REQUESTS):
    """Return the KV-cache dict for ``request_id``, creating it if needed.

    ``kv_caches`` is kept in least-recently-used order. The touched entry
    moves to the end, and the oldest entries are evicted once more than
    ``max_requests`` requests are stored. This stops abandoned requests from
    holding past_key_values tensors forever. A request whose cache was
    evicted restarts from an empty cache.
    """
    with _KV_CACHE_LOCK:
        entry = kv_caches.pop(request_id, None)
        if entry is None:
            entry = {}
        kv_caches[request_id] = entry
        while len(kv_caches) > max(1, max_requests):
            kv_caches.pop(next(iter(kv_caches)))
        return entry


def _static_shape(shape):
    """Replace every unknown (symbolic / None) or non-positive dim with 1."""
    return [dim if isinstance(dim, int) and dim > 0 else 1 for dim in shape]


def _empty_kv_shape(shape):
    """Concrete shape for a fresh, empty past_key_values input.

    Fixed positive dims are kept. Unknown dims (symbolic names, ``None``, or
    non-positive values) become 1 on axis 0 and 0 elsewhere, i.e. batch 1
    with an empty past sequence. This assumes axis 0 is the batch axis, as in
    the [1, 16, 0, 128] default used when the shape is empty. TODO confirm
    per exported model.
    """
    safe_shape = [
        dim if isinstance(dim, int) and dim > 0 else (1 if axis == 0 else 0)
        for axis, dim in enumerate(shape or [])
    ]
    return safe_shape or [1, 16, 0, 128]


@sio.on('swarm_forward')
def on_swarm_forward(data):
    """Run one pipeline hop of swarm inference on a locally loaded ONNX shard.

    Input: ``data`` carries ``requestId``, ``shardId`` and ``payload``. The
    payload is raw float32 activation bytes, zlib-compressed when
    ``compressed`` is true.

    Session choice: the activations feed session ``shard_<shardId>.onnx``.
    If that shard is not loaded, the first loaded session is used instead.
    past_key_values are threaded through a bounded per-request cache in
    ``gateway_state``.

    Output: the result bytes are emitted as ``swarm_forward_result``,
    zlib-compressed when larger than 1 MiB. ``shape`` is the float32 element
    count.

    Failure and fallback cases:
      * ``session.run`` fails: random activations of the declared output
        shape are returned (deliberate degraded mode).
      * No session is loaded at all: the payload is echoed back after 0.5 s.
      * Any other exception is reported as ``swarm_forward_error``.
    """
    global gateway_state
    request_id = data.get('requestId')
    shard_id = data.get('shardId')
    payload = data.get('payload')
    is_compressed = data.get('compressed', False)
    try:
        import zlib
        import numpy as np

        if is_compressed:
            payload = zlib.decompress(payload)
        activation_array = np.frombuffer(payload, dtype=np.float32)
        # Stateful gateway cache, created lazily on the first forward pass.
        if 'gateway_state' not in globals():
            gateway_state = {'kv_caches': {}}
        kv_caches = gateway_state['kv_caches']

        session = onnx_sessions.get(f"shard_{shard_id}.onnx")
        if not session and len(onnx_sessions) > 0:
            session = list(onnx_sessions.values())[0]

        if session:
            model_inputs = session.get_inputs()
            primary_input = model_inputs[0].name if model_inputs else None
            input_feed = {}
            for onnx_in in model_inputs:
                iname, ishape, itype = onnx_in.name, onnx_in.shape, onnx_in.type
                if iname == primary_input or 'embed' in iname or 'output_0' in iname:
                    # 1. Hidden states coming from the previous hop.
                    act_val = activation_array
                    if ishape and isinstance(ishape[-1], int):
                        hidden_dim = ishape[-1]
                        try:
                            seq_batch = max(1, len(act_val) // hidden_dim)
                            act_val = act_val.reshape((1, seq_batch, hidden_dim))
                        except (ValueError, ZeroDivisionError):
                            pass  # size does not fit the hidden dim: keep the flat array
                    if 'int64' in itype:
                        act_val = act_val.astype(np.int64)
                    input_feed[iname] = act_val
                elif 'past_key_values' in iname:
                    # 2. KV cache: reuse this request's previous "present" output.
                    request_cache = _touch_kv_cache(kv_caches, request_id)
                    if iname in request_cache:
                        input_feed[iname] = request_cache[iname]
                    else:
                        dtype = np.int64 if 'int64' in itype else np.float32
                        input_feed[iname] = np.zeros(_empty_kv_shape(ishape), dtype=dtype)
                elif 'attention_mask' in iname:
                    # 3. Attention mask & position ids: placeholder values.
                    input_feed[iname] = np.ones(_static_shape(ishape), dtype=np.int64)
                elif 'position_ids' in iname:
                    input_feed[iname] = np.zeros(_static_shape(ishape), dtype=np.int64)
                else:
                    input_feed[iname] = np.zeros(_static_shape(ishape), dtype=np.float32)

            try:
                out_names = [o.name for o in session.get_outputs()]
                outputs = session.run(out_names, input_feed)
                result_array = None
                for oname, oval in zip(out_names, outputs):
                    if 'present' in oname:
                        request_cache = _touch_kv_cache(kv_caches, request_id)
                        request_cache[oname.replace('present', 'past_key_values')] = oval
                    elif result_array is None or 'output' in oname or 'logits' in oname:
                        result_array = oval
            except Exception as onnx_err:
                # Deliberate degraded mode: random activations of the declared
                # output shape keep the pipeline moving.
                print(f"[Gateway] ⚠️ Lỗi ONNX: {onnx_err}. Bật Fallback Mocking...")
                out_shape = _static_shape(session.get_outputs()[0].shape) or [1, 20, 3584]
                result_array = np.random.rand(*out_shape).astype(np.float32)
            if result_array is None:
                raise RuntimeError("ONNX session returned no activation output")
            result_bytes = result_array.astype(np.float32).tobytes()
        else:
            # No shard loaded at all: simulate latency and echo the activations.
            time.sleep(0.5)
            result_bytes = payload

        # Float32 element count, taken BEFORE compression: the compressed byte
        # length says nothing about the tensor size.
        element_count = len(result_bytes) // 4
        out_compressed = False
        if len(result_bytes) > 1024 * 1024:
            result_bytes = zlib.compress(result_bytes)
            out_compressed = True
        sio.emit('swarm_forward_result', {
            'requestId': request_id, 'shardId': shard_id,
            'payload': result_bytes, 'shape': element_count,
            'compressed': out_compressed, 'encrypted': False,
        })
    except Exception as e:
        sio.emit('swarm_forward_error', {'requestId': request_id, 'shardId': shard_id, 'error': str(e)})
def start_socketio():
    """Keep the Socket.IO link to the Data Center alive; loops forever.

    It runs on a daemon thread started at app startup. Every 10 seconds, if
    the client is not connected, it tries to connect to ``CENTER_URL``.
    Failures are logged and retried on the next tick instead of being
    silently ignored.
    """
    while True:
        try:
            if not sio.connected:
                sio.connect(CENTER_URL, socketio_path="/socket.io")
        except Exception as e:
            print(f"[Socket.IO] Connection to {CENTER_URL} failed: {e}")
        time.sleep(10)
# Metrics helper from the sibling module ``shared_metrics`` (same folder, so
# no sys.path manipulation is needed). When it is missing, heartbeats report
# a fixed placeholder snapshot instead.
try:
    from shared_metrics import get_hf_metrics
except ImportError:
    def get_hf_metrics():
        """Return a static stand-in metrics snapshot (shared_metrics is unavailable)."""
        snapshot = dict(
            total_memory_mb=16384,
            used_memory_mb=4096,
            cpu_count=2,
            cpu_percent=0,
            platform="Linux",
            max_workers=6,
            hardware_tier="cpu-basic",
            gpu=None,
        )
        return snapshot
app = FastAPI(
    title="EvoNet Data Center - B2B API Gateway",
    description="OpenAI-compatible Gateway routing traffic to DePIN or Local Extreme Batching",
    version="2.0.0",
)
# Base URL of the EvoNet Center backend (registration, key validation, swarm).
CENTER_URL = os.environ.get('CENTER_URL', 'https://evonet-ai.onrender.com')
# Optional shared partner secret that is accepted without asking the backend.
# There is deliberately no built-in default: a secret hard-coded in public
# source would let anyone authenticate, including on the admin LoRA endpoint.
# Unset or empty -> None, which never matches a bearer token.
B2B_SECRET = os.environ.get('B2B_SECRET') or None
if B2B_SECRET is None:
    print("[Gateway] B2B_SECRET not set: only keys validated by the Center backend are accepted.")
# Extracts the "Authorization: Bearer <token>" credentials for the auth dependency.
security = HTTPBearer()
print("[Gateway] Khởi tạo V-Neural Extreme JIT Engine (Qwen3.5-0.8B-GGUF)...")
# hf_hub_download is also used later by the LoRA hot-load endpoint.
from huggingface_hub import hf_hub_download

# Local GGUF weights for the fallback model. The relative path resolves
# against the current working directory; the file is downloaded into "."
# only when it is missing.
model_path = "Qwen3.5-0.8B-Q4_K_M.gguf"
if not os.path.exists(model_path):
    print(f"[Gateway] Đang tải {model_path} (khoảng 533MB)...")
    model_path = hf_hub_download(
        repo_id="unsloth/Qwen3.5-0.8B-GGUF",
        filename=model_path,
        local_dir=".",
    )

# Global llama.cpp model handle, assigned by load_llm().
llm = None
def load_llm(lora_path=None):
    """(Re)load the local GGUF model into the global ``llm``, optionally with a LoRA adapter.

    The new ``Llama`` instance replaces ``llm`` only after it has loaded
    successfully. A failed hot-reload (e.g. a bad LoRA file) therefore leaves
    the previously loaded model serving instead of dropping it to None.

    Args:
        lora_path: local path of a LoRA adapter file, or None for the base model.

    Returns:
        True if the model was loaded, False otherwise (the error is logged).
    """
    global llm
    try:
        from llama_cpp import Llama
    except ImportError:
        print("[Gateway] Cảnh báo: Chưa cài đặt llama-cpp-python. Fallback to mock.")
        return False
    try:
        print(f"[Gateway] Đang nạp mô hình vào RAM (LoRA: {lora_path})...")
        new_llm = Llama(
            model_path=model_path,
            lora_path=lora_path,
            n_ctx=2048,
            n_threads=2,  # tuned for a 2-vCPU machine
            verbose=False,
        )
    except Exception as e:
        print(f"[Gateway] Lỗi nạp mô hình: {e}")
        return False
    llm = new_llm
    print("[Gateway] Nạp mô hình thành công!")
    return True
import re

# Initial model load at import time (base model, no LoRA adapter).
load_llm()
def fallback_infer(messages_dicts, max_tokens=512, temperature=0.7):
    """Answer a chat locally with the GGUF model; used when the DePIN swarm is not used.

    A rule is appended to the first ``system`` message so the model reasons
    and answers in Vietnamese. The rule is applied to copies of the messages,
    so the caller's dicts are left untouched.
    NOTE(review): when no system message is present, the rule is not added
    at all.

    Args:
        messages_dicts: list of {"role": ..., "content": ...} dicts.
        max_tokens: generation limit (default 512; None also means 512).
        temperature: sampling temperature (default 0.7; None also means 0.7).

    Returns:
        The stripped assistant reply, or a Vietnamese apology string when no
        model is loaded or inference fails.
    """
    # Read the global once: a concurrent load_llm() may rebind it mid-call.
    model = llm
    if model is None:
        return "Xin lỗi, hệ thống đang bị lỗi tải mô hình GGUF."
    messages = [dict(m) for m in messages_dicts]
    for m in messages:
        if m["role"] == "system":
            m["content"] += "\n[LUẬT TỐI CAO: Bạn được phép suy nghĩ (nếu cần), nhưng MỌI QUÁ TRÌNH SUY NGHĨ VÀ CÂU TRẢ LỜI ĐỀU PHẢI VIẾT BẰNG TIẾNG VIỆT 100%. Tuyệt đối không dùng tiếng Anh.]"
            break
    try:
        res = model.create_chat_completion(
            messages=messages,
            max_tokens=512 if max_tokens is None else max_tokens,
            temperature=0.7 if temperature is None else temperature,
        )
        text = res["choices"][0]["message"]["content"] or ""
        return text.strip()
    except Exception as e:
        print(f"[Gateway Error] Lỗi inference: {e}")
        return "Xin lỗi, hệ thống gặp lỗi khi xử lý câu hỏi của bạn."
def start_heartbeat(port):
    """Report liveness to the Center every 15 seconds; loops forever.

    Each tick runs two independent best-effort steps, and each step
    swallows its own errors:
      1. POST this gateway's role, port, public URL (SPACE_HOST) and metrics
         to the datacenter register endpoint, with a 5 s timeout.
      2. If the Socket.IO link is up, emit ``worker_heartbeat`` with the CPU
         percentage and this process's resident memory in MB.
    """
    def _post_registration():
        requests.post(
            f"{CENTER_URL}/api/v1/admin/datacenter/register",
            json={
                "role": "Gateway",
                "port": port,
                "public_url": os.environ.get("SPACE_HOST", ""),
                "metrics": get_hf_metrics(),
            },
            timeout=5,
        )

    def _emit_socket_heartbeat():
        if not sio.connected:
            return
        sio.emit('worker_heartbeat', {
            'cpu_usage': psutil.cpu_percent(),
            'memory_mb': int(psutil.Process().memory_info().rss / 1024 / 1024),
        })

    while True:
        for step in (_post_registration, _emit_socket_heartbeat):
            try:
                step()
            except Exception:
                pass  # best effort: try again on the next tick
        time.sleep(15)
@app.on_event("startup")
async def startup_event():
    """Start the heartbeat and Socket.IO keep-alive loops on daemon threads."""
    print("[Gateway] Bắt đầu Server GGUF...")
    port = int(os.environ.get("PORT", 7860))
    background_jobs = (
        (start_heartbeat, (port,)),
        (start_socketio, ()),
    )
    for job, job_args in background_jobs:
        threading.Thread(target=job, args=job_args, daemon=True).start()
@app.on_event("shutdown")
async def shutdown_event():
    """Log server shutdown (the background threads are daemons and are not joined)."""
    message = "[Gateway] Tắt Server..."
    print(message)
# --- Schemas ---
# One OpenAI-style chat message.
class ChatMessage(BaseModel):
    role: str  # e.g. "system" / "user" / "assistant"; not validated here
    content: str
# Request body of POST /v1/chat/completions (OpenAI-compatible subset).
class ChatCompletionRequest(BaseModel):
    model: str = "evonet-extreme-7b"  # echoed back in the response's "model" field
    messages: List[ChatMessage]
    # NOTE(review): the handler does not forward temperature anywhere, and
    # forwards max_tokens only on the swarm route.
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 512
# --- Dependency: validate the API key via the Center backend ---
async def verify_b2b_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency that authenticates the bearer token of a B2B request.

    A token equal to ``B2B_SECRET`` (when one is configured) is accepted
    locally, using a constant-time comparison. Any other token is accepted
    only if the Center backend's ``/api/v1/auth/validate-key`` answers 200.
    That call runs in a worker thread so it does not block the event loop.

    Returns:
        The accepted token string.

    Raises:
        HTTPException: 401 when the key is rejected or the backend is unreachable.
    """
    import hmac

    token = credentials.credentials
    # Fast path for the shared partner secret (bytes comparison also copes with non-ASCII input).
    if B2B_SECRET and hmac.compare_digest(token.encode("utf-8"), B2B_SECRET.encode("utf-8")):
        return token
    try:
        res = await asyncio.to_thread(
            requests.get,
            f"{CENTER_URL}/api/v1/auth/validate-key",
            headers={"Authorization": f"Bearer {token}"},
            timeout=2,
        )
        if res.status_code == 200:
            return token
    except Exception as e:
        print(f"[Gateway] Validate Key Error: {e}")
    raise HTTPException(status_code=401, detail="Invalid API Key or Center Backend unavailable")
@app.get("/")
def read_root():
    """Health/info endpoint describing the gateway."""
    info = dict(
        status="online",
        service="EvoNet B2B API Gateway",
        message="Gateway is running. Send POST requests to /v1/chat/completions",
        lora_supported=True,
    )
    return info
# Request body of POST /v1/admin/load-lora.
class LoadLoraRequest(BaseModel):
    repo_id: Optional[str] = None  # Hugging Face repo; the HF_LORA_REPO env var is used when omitted
    filename: str  # adapter file inside that repo
@app.post("/v1/admin/load-lora")
async def api_load_lora(req: LoadLoraRequest, token: str = Depends(verify_b2b_token)):
    """Hot-load a LoRA adapter from Hugging Face into the local model (for the Universal Agent).

    Steps:
      1. Download ``req.filename`` from ``req.repo_id``, or from the
         ``HF_LORA_REPO`` env var when no repo is given. ``HF_ACCESS_TOKEN``
         is used for private repos.
      2. Reload the model with the adapter.

    Both steps run in worker threads so the event loop stays responsive.

    Raises:
        HTTPException: 400 when no repository is given or configured; 500
        when the download or the reload fails.
    """
    hf_token = os.environ.get("HF_ACCESS_TOKEN", None)
    target_repo = req.repo_id or os.environ.get("HF_LORA_REPO")
    if not target_repo:
        # Raised outside the try below so it stays a 400 instead of becoming a 500.
        raise HTTPException(status_code=400, detail="Không có repo_id. Vui lòng truyền repo_id hoặc cấu hình HF_LORA_REPO")
    try:
        print(f"[Gateway Admin] Yêu cầu nạp LoRA từ {target_repo}/{req.filename} (Private Mode: {bool(hf_token)})...")
        lora_local_path = await asyncio.to_thread(
            hf_hub_download,
            repo_id=target_repo,
            filename=req.filename,
            local_dir=".",
            token=hf_token,
        )
        loaded = await asyncio.to_thread(load_llm, lora_local_path)
    except Exception as e:
        print(f"[Gateway Admin] Lỗi tải LoRA: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    # load_llm logs its own errors; an explicit False means the reload failed.
    if loaded is False:
        raise HTTPException(status_code=500, detail=f"Failed to load LoRA: {req.filename}")
    return {"success": True, "message": f"Đã nạp thành công LoRA: {req.filename}"}
async def _count_swarm_workers():
    """Return how many workers the Center reports at /api/v1/depin-status (0 on any failure)."""
    try:
        res = await asyncio.to_thread(requests.get, f"{CENTER_URL}/api/v1/depin-status", timeout=2)
        if res.status_code == 200:
            return len(res.json().get('workers', []))
    except Exception as e:
        print(f"[Gateway] Center connection failed: {e}")
    return 0


async def _infer_via_swarm(prompt, max_tokens, token, attempts=2):
    """Ask the Center's DePIN swarm to complete ``prompt`` and return the result text.

    Makes up to ``attempts`` POSTs; only a connection error is followed by a
    1 s pause before the next try.

    Raises:
        Exception: when no attempt produced an HTTP 200 response.
    """
    res = None
    for attempt in range(attempts):
        try:
            res = await asyncio.to_thread(
                requests.post,
                f"{CENTER_URL}/api/v1/swarm-inference",
                json={"prompt": prompt, "max_tokens": max_tokens},
                headers={"Authorization": f"Bearer {token}"},
                timeout=15,
            )
            if res.status_code == 200:
                break
        except Exception as ex:
            print(f"[Gateway Routing] Lỗi kết nối Swarm (Thử lại {attempt+1}/{attempts}): {ex}")
            await asyncio.sleep(1)
    if res is not None and res.status_code == 200:
        # A missing or null "result" becomes "" so the usage count below cannot crash.
        return res.json().get('result') or ''
    raise Exception("Swarm returned non-200 after retries")


@app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionRequest, request: Request, token: str = Depends(verify_b2b_token)):
    """OpenAI-compatible chat completion endpoint.

    Routing:
      * If the Center's depin-status reports any workers, the LAST message is
        sent to the DePIN swarm.
      * If the swarm call fails, or no workers are reported, the whole
        conversation is answered by the local GGUF model.

    All outbound HTTP and local inference run in worker threads so the event
    loop is never blocked. ``usage`` values are whitespace word counts, not
    real tokenizer counts.

    Raises:
        HTTPException: 400 when ``messages`` is empty.
    """
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages array cannot be empty")
    messages_dicts = [{"role": m.role, "content": m.content} for m in req.messages]
    last_message = req.messages[-1].content
    print(f"[Gateway] B2B Request for model {req.model}")

    # Step 1: check swarm health.
    active_nodes = await _count_swarm_workers()

    if active_nodes > 0:
        # Step 2a: route to the DePIN swarm.
        print(f"[Gateway Routing] ➡️ Chuyển hướng tới DePIN Swarm ({active_nodes} nodes đang rảnh)")
        try:
            result_text = str(await _infer_via_swarm(last_message, req.max_tokens, token))
        except Exception as e:
            print(f"[Gateway Routing] Lỗi từ DePIN Swarm ({e}), chuyển sang Fallback Local.")
            result_text = str(await asyncio.to_thread(fallback_infer, messages_dicts))
    else:
        # Step 2b: fall back to local inference.
        print("[Gateway Routing] ➡️ DePIN bận/thiếu Node. Chuyển hướng xử lý tại Cụm Server Nội bộ.")
        result_text = str(await asyncio.to_thread(fallback_infer, messages_dicts))

    # Step 3: format an OpenAI-style response.
    prompt_tokens = len(last_message.split())
    completion_tokens = len(result_text.split())
    return {
        # uuid instead of a millisecond timestamp, so concurrent requests cannot share an id.
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": result_text,
            },
            "finish_reason": "stop",
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    banner = [
        "=================================================",
        "🏢 EvoNet Data Center - B2B API Gateway (FastAPI)",
        "⚡ V-Neural Extreme Continuous Batching ACTIVE",
        f"✅ Running on port {port}",
        "=================================================",
    ]
    for line in banner:
        print(line)
    # Pass the app object rather than the "main:app" import string. The
    # string makes uvicorn import this file a second time (as module "main"),
    # which re-runs every module-level side effect, including the model
    # download/load and the second Socket.IO client. An app object is
    # supported because there is one worker and no reload.
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info", workers=1)