|
|
|
|
|
import json
import logging
import math
import os
import time
from datetime import datetime
from typing import Optional

import psutil
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
|
|
|
|
|
|
|
|
|
# Route all router traces (one JSON line per handled query, emitted by
# route_query) plus warnings/errors to a dedicated log file.
logging.basicConfig(
    filename='router_tracing.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
|
|
|
|
|
|
|
|
|
app = FastAPI(title="System1/System2 Router", version="1.0")

# "System 1" = small/fast model used for simple queries;
# "System 2" = larger model reserved for queries flagged as complex.
SYSTEM1_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
SYSTEM2_MODEL = "HuggingFaceH4/zephyr-7b-beta"
|
|
|
|
|
|
def estimate_model_memory(model_id, bits=4, is_system2=False):
    """Estimate the memory footprint (in GiB) of a model at a given quantization.

    Args:
        model_id: Hugging Face model identifier used to fetch the config.
        bits: Quantization width — 4, 8, or anything else (treated as fp16).
        is_system2: Selects which hard-coded fallback size to use when the
            config-based estimate is unavailable (large System-2 model vs
            small System-1 model).

    Returns:
        Estimated memory in GiB, including a 20% overhead margin.
    """
    try:
        config = AutoConfig.from_pretrained(model_id)

        # BUG FIX: the previous code called .numel() on plain ints taken from
        # config.to_dict().values(), which raised AttributeError on every call
        # and silently fell through to the fallback. Estimate the parameter
        # count from standard transformer config fields instead.
        hidden = config.hidden_size
        layers = config.num_hidden_layers
        vocab = getattr(config, "vocab_size", 0)
        intermediate = getattr(config, "intermediate_size", 4 * hidden)

        # Rough decoder-only transformer count: attention (~4*h^2) plus a
        # gated MLP (~3*h*intermediate) per layer, plus the embedding table.
        total_params = layers * (4 * hidden * hidden + 3 * hidden * intermediate) + vocab * hidden

        if bits == 4:
            bytes_per_param = 0.5
        elif bits == 8:
            bytes_per_param = 1
        else:
            bytes_per_param = 2  # fp16/bf16 weights

        base_memory = total_params * bytes_per_param

        # 20% headroom for activations, KV cache, and framework buffers.
        total_memory = base_memory * 1.2

        return total_memory / (1024 ** 3)

    except Exception:  # narrowed from bare `except:` so SystemExit/KeyboardInterrupt escape
        # Conservative hard-coded fallbacks when the config cannot be read.
        if is_system2:
            return 6.0 if bits == 4 else 14.0
        else:
            return 1.2 if bits == 8 else 2.5
|
|
|
|
|
|
def get_device_map(model_id, bits=4, is_system2=False):
    """Pick a device placement strategy for a model given available resources.

    Returns "cpu" (no GPU or non-4-bit oversized model), "auto" (model fits
    comfortably on the GPU), or an explicit accelerate device map that keeps
    most of a 4-bit model on GPU 0 while spilling a few modules to CPU.
    """
    has_gpu = torch.cuda.is_available()
    ram_gb = psutil.virtual_memory().total / (1024 ** 3)
    model_gb = estimate_model_memory(model_id, bits, is_system2)

    logging.info(f"System memory: {ram_gb:.2f}GB, Estimated model size: {model_gb:.2f}GB")

    # No CUDA device at all -> everything stays on the CPU.
    if not has_gpu:
        logging.warning("No GPU detected - using CPU only")
        return "cpu"

    gpu_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)

    # Comfortable fit (80% headroom) -> let accelerate place it automatically.
    if model_gb < gpu_gb * 0.8:
        return "auto"

    logging.warning(f"Model size ({model_gb:.2f}GB) exceeds GPU capacity ({gpu_gb:.2f}GB). Using CPU offloading.")

    # Non-4-bit models that don't fit are loaded entirely on CPU.
    if bits != 4:
        return "cpu"

    # 4-bit: default everything to GPU 0, then pin selected modules to CPU.
    placement = {"": 0}
    for cpu_module in ("model.layers.0", "model.layers.1", "model.norm", "lm_head"):
        placement[cpu_module] = "cpu"
    return placement
|
|
|
|
|
|
|
|
|
def _load_tokenizer(model_id):
    """Load a tokenizer, guaranteeing a pad token (reuses EOS when unset)."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def load_quantized_model(model_id, is_system2=False):
    """Load a quantized model with proper memory management.

    System-2 models are loaded in 4-bit NF4 (double-quantized); System-1
    models in 8-bit. If the quantized load fails with an out-of-memory
    error, falls back to a plain (unquantized) CPU load.

    Args:
        model_id: Hugging Face model identifier.
        is_system2: True for the large reasoning model, False for the small one.

    Returns:
        (tokenizer, model) tuple.

    Raises:
        Re-raises the original load error when no fallback succeeds.
    """
    try:
        # Prefer bf16 compute when the GPU supports it; fp16 otherwise.
        compute_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
        device_map = get_device_map(model_id, bits=4 if is_system2 else 8, is_system2=is_system2)

        if is_system2:
            # 4-bit NF4 with double quantization for the large model.
            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=True,
                llm_int8_enable_fp32_cpu_offload=True
            )
        else:
            # 8-bit for the small model.
            quant_config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_enable_fp32_cpu_offload=True,
                llm_int8_threshold=6.0
            )

        # Refactor: tokenizer loading was duplicated verbatim in the OOM
        # fallback below — now shared via _load_tokenizer.
        tokenizer = _load_tokenizer(model_id)

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=quant_config,
            device_map=device_map,
            offload_folder="offload_folder",
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )

        logging.info(f"Successfully loaded {model_id} with device_map: {device_map}")
        return tokenizer, model

    except Exception as e:
        logging.error(f"Model load failed for {model_id}: {str(e)}")

        # Only OOM-looking failures get the unquantized CPU fallback.
        if "out of memory" in str(e).lower() or "oom" in str(e).lower():
            logging.warning("GPU memory insufficient. Falling back to CPU loading.")
            try:
                tokenizer = _load_tokenizer(model_id)
                model = AutoModelForCausalLM.from_pretrained(
                    model_id,
                    device_map="cpu",
                    trust_remote_code=True
                )
                logging.info(f"Fallback CPU loading succeeded for {model_id}")
                return tokenizer, model
            except Exception as cpu_e:
                logging.error(f"CPU fallback also failed: {str(cpu_e)}")
                raise

        raise
|
|
|
|
|
|
|
|
|
# Eagerly load both models at import time so requests never hit a cold model.
# NOTE(review): this runs on module import and blocks server startup until
# both models (potentially multi-GB downloads) are resident — confirm this
# is acceptable for the deployment.
print("Loading quantized models with memory optimization...")
tokenizer1, model1 = load_quantized_model(SYSTEM1_MODEL)
tokenizer2, model2 = load_quantized_model(SYSTEM2_MODEL, is_system2=True)
print("Models loaded successfully!")
|
|
|
|
|
|
|
|
|
# Build generation pipelines around the already-loaded models.
#
# BUG FIX: the models were already quantized and dispatched onto devices by
# load_quantized_model (via an accelerate device_map), so the pipelines must
# not pass device_map (or torch_dtype) again — those arguments only apply when
# the pipeline loads the model from an identifier itself, and re-specifying
# them for an already-dispatched model conflicts with the existing placement
# (transformers warns/errors about moving accelerate-loaded models).
system1_pipe = pipeline(
    "text-generation",
    model=model1,
    tokenizer=tokenizer1,
    max_new_tokens=128,
    do_sample=True,
    temperature=0.7,
    pad_token_id=tokenizer1.eos_token_id
)

system2_pipe = pipeline(
    "text-generation",
    model=model2,
    tokenizer=tokenizer2,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.8,
    pad_token_id=tokenizer2.eos_token_id
)
|
|
|
|
|
|
|
|
|
# Single words that signal a reasoning-heavy query (matched as whole tokens).
COMPLEX_KEYWORDS = {
    'explain', 'why', 'how', 'compare', 'analyze', 'reason',
    'steps', 'detailed', 'difference', 'advantage', 'disadvantage',
    'calculate', 'derive', 'formula', 'math', 'equation'
}

# Multi-word comparison phrases that signal a complex query.
COMPLEX_PHRASES = ('pros and cons', 'advantages and disadvantages', 'compare and contrast')


def is_semantically_complex(query: str) -> bool:
    """Rule-based semantic complexity check.

    A query is "complex" (routed to the System-2 model) when it contains a
    reasoning keyword as a whole word, is longer than 15 tokens, or asks for
    a comparison ("X vs Y", "pros and cons", ...).

    BUG FIX: keywords and 'vs' were previously matched as raw substrings of
    the query, so e.g. "shower" matched 'how' and "canvas" matched 'vs',
    mis-routing simple queries to the expensive model. Single-word triggers
    are now matched token-by-token; only multi-word phrases remain substring
    matches.
    """
    lower_query = query.lower()
    # Strip common punctuation so "why?" still matches the keyword "why".
    tokens = [t.strip('.,!?;:()"\'') for t in lower_query.split()]

    if any(token in COMPLEX_KEYWORDS for token in tokens):
        return True

    # Long queries tend to need the more deliberate model.
    if len(tokens) > 15:
        return True

    if 'vs' in tokens or 'versus' in tokens:
        return True

    if any(phrase in lower_query for phrase in COMPLEX_PHRASES):
        return True

    return False
|
|
|
|
|
|
def calculate_entropy(response):
    """Return the average per-step token entropy from generation scores.

    Accepts either a dict carrying a "scores" key or a list whose first
    element carries one; any other shape — or any failure — yields 0.
    """
    try:
        # Locate the score sequence, tolerating both output shapes.
        if isinstance(response, dict):
            score_items = response.get("scores", [])
        elif isinstance(response, list) and len(response) > 0:
            score_items = response[0].get("scores", [])
        else:
            score_items = []

        per_step = []
        for entry in score_items:
            # Normalize the various containers a score step may arrive in.
            if isinstance(entry, tuple):
                raw_logits = entry[0]
            elif hasattr(entry, 'logits'):
                raw_logits = entry.logits
            else:
                raw_logits = entry

            if not isinstance(raw_logits, torch.Tensor):
                continue

            # Shannon entropy of the softmax distribution; epsilon guards log(0).
            dist = torch.softmax(raw_logits, dim=-1)
            per_step.append(-torch.sum(dist * torch.log(dist + 1e-10)).item())

        if not per_step:
            return 0
        return sum(per_step) / len(per_step)
    except Exception as e:
        logging.warning(f"Entropy calculation failed: {str(e)}")
        return 0
|
|
|
|
|
|
|
|
|
class QueryRequest(BaseModel):
    """Request body for /query."""
    text: str  # raw user query; the handler strips surrounding whitespace
|
|
|
|
|
|
class RouterResponse(BaseModel):
    """Response body for /query: the generated answer plus routing metadata."""
    response: str        # generated answer (or canned apology on failure)
    model_used: str      # "system1", "system2", or "error_fallback"
    routing_reason: str  # e.g. "semantic_complexity", "simple_query", "error_recovery"
    latency_ms: float    # end-to-end handler latency in milliseconds
    # BUG FIX: was `entropy: float = None` — a None default on a plain float
    # field is invalid under pydantic v2 (the field is not implicitly
    # optional); declare it Optional explicitly. Default/behavior unchanged.
    entropy: Optional[float] = None
|
|
|
|
|
|
|
|
|
@app.post("/query", response_model=RouterResponse)
async def route_query(request: QueryRequest) -> RouterResponse:
    """Route a query to the fast System-1 model or the deliberate System-2 model.

    Routing is purely rule-based (see is_semantically_complex). Every handled
    query is traced as a JSON line via logging. On any processing error the
    handler returns a canned apology payload instead of raising a 5xx.
    """
    start_time = time.time()
    query = request.text.strip()

    if not query:
        raise HTTPException(status_code=400, detail="Empty query")

    routing_reason = ""
    # NOTE(review): entropy is initialized here but never recomputed anywhere
    # in this handler, so responses always report 0.0 — confirm whether
    # calculate_entropy() was meant to be applied to the generation output.
    entropy = 0.0
    used_model = ""
    result_text = ""

    try:
        if is_semantically_complex(query):
            # Complex query -> larger System-2 model with the plain prompt.
            routing_reason = "semantic_complexity"
            response = system2_pipe(query, max_new_tokens=150)
            used_model = "system2"
        else:
            # Simple query -> System-1 model with a chat-style template.
            # NOTE(review): this is the Zephyr chat format — confirm the
            # TinyLlama chat model expects the same template.
            formatted_query = f"<|system|>\nYou are a helpful assistant.</s>\n<|user|>\n{query}</s>\n<|assistant|>\n"

            response = system1_pipe(
                formatted_query,
                max_new_tokens=100,
                return_full_text=False,
                pad_token_id=tokenizer1.eos_token_id
            )

            if isinstance(response, list) and len(response) > 0:
                generated_text = response[0]['generated_text'].strip()
                # return_full_text=False should already drop the prompt;
                # this replace() is a belt-and-braces cleanup.
                result_text = generated_text.replace(formatted_query, "").strip()
            else:
                result_text = str(response).strip()

            used_model = "system1"
            routing_reason = "simple_query"

        # System-2 output includes the prompt (return_full_text default);
        # strip the first occurrence of the original query from it.
        if not result_text and used_model == "system2":
            if isinstance(response, list) and len(response) > 0:
                result_text = response[0]['generated_text'].replace(query, "", 1).strip()

        # One structured JSON trace line per query for offline analysis.
        trace_data = {
            "timestamp": datetime.utcnow().isoformat(),
            "query": query,
            "model_used": used_model,
            "routing_reason": routing_reason,
            "entropy": entropy,
            "response": result_text
        }
        logging.info(json.dumps(trace_data))

        latency_ms = (time.time() - start_time) * 1000

        return RouterResponse(
            response=result_text,
            model_used=used_model,
            routing_reason=routing_reason,
            latency_ms=round(latency_ms, 2),
            entropy=round(entropy, 2)
        )

    except Exception as e:
        # Degrade gracefully: log the failure and return a canned apology
        # (HTTP 200) rather than surfacing a 5xx to the client.
        error_msg = f"Processing error for query '{query[:20]}...': {str(e)}"
        logging.error(error_msg)

        return RouterResponse(
            response="I apologize, but I encountered an error processing your request. Please try again with a simpler query.",
            model_used="error_fallback",
            routing_reason="error_recovery",
            latency_ms=round((time.time() - start_time) * 1000, 2),
            entropy=0.0
        )
|
|
|
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Report service liveness plus model identities and hardware capacity."""
    cuda = torch.cuda.is_available()
    gpu_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3) if cuda else 0
    ram_gb = psutil.virtual_memory().total / (1024 ** 3)

    return {
        "status": "healthy",
        "system1": SYSTEM1_MODEL,
        "system2": SYSTEM2_MODEL,
        "device": "cuda" if cuda else "cpu",
        "gpu_memory_gb": round(gpu_gb, 2),
        "cpu_memory_gb": round(ram_gb, 2),
    }
|
|
|
|
|
|
|
|
|
@app.post("/warmup")
async def warmup_models():
    """Run one tiny generation through each pipeline to pay one-time startup
    costs (CUDA context, kernel/cache setup) before real traffic arrives."""
    try:
        # Small fixed prompts; the outputs are discarded.
        system1_pipe("Hello, how are you?", max_new_tokens=10)

        system2_pipe("What is the capital of France?", max_new_tokens=10)

        return {"status": "models warmed up successfully"}
    except Exception as e:
        # Warmup failure is reported in the payload, not raised — the
        # service can still serve (just with a slow first request).
        return {"status": "warmup failed", "error": str(e)}
|
|
|
|
|
|
if __name__ == "__main__":
    # Local import keeps uvicorn optional when the app is mounted by an
    # external ASGI server instead of run directly.
    import uvicorn
    print("Starting server with memory-optimized configuration...")
    # NOTE(review): binds to all interfaces (0.0.0.0) — confirm this is
    # intended for the deployment environment.
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")