# Uploaded by Aditi132 ("Upload 2 files", commit 01dc3a8, verified) — Hugging Face
# page residue converted to a comment so this file parses as Python.
# router_app.py
import json
import logging
import math
import os
import re
import time
from datetime import datetime, timezone
from typing import Optional

import psutil
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
# Configure file-based logging; /query traces are written here as JSON lines.
logging.basicConfig(
    filename='router_tracing.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Initialize FastAPI app
app = FastAPI(title="System1/System2 Router", version="1.0")

# Free & Open Models (Apache 2.0 / MIT licensed).
# Dual-process naming: System 1 = small/fast model for simple queries,
# System 2 = larger model for complex reasoning.
SYSTEM1_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # 1.1B proxy for distilled Llama3-1B
SYSTEM2_MODEL = "HuggingFaceH4/zephyr-7b-beta"  # 7B proxy for Llama3-8B
def estimate_model_memory(model_id, bits=4, is_system2=False):
    """Estimate memory requirements (in GiB) for a model at a given quantization.

    Args:
        model_id: Hugging Face model identifier used to fetch the config.
        bits: Quantization width — 4, 8, or anything else treated as 16-bit.
        is_system2: Selects the fallback estimate when the config cannot be
            read (True -> 7B-class model, False -> 1.1B-class model).

    Returns:
        Estimated memory in GiB, including a 20% overhead for activations.
    """
    try:
        config = AutoConfig.from_pretrained(model_id)
        # Approximate the parameter count from architecture hyperparameters.
        # (The previous code summed the raw config values themselves — e.g.
        # vocab_size + hidden_size — which is not a parameter count at all.)
        hidden = config.hidden_size
        layers = config.num_hidden_layers
        vocab = config.vocab_size
        intermediate = getattr(config, "intermediate_size", 4 * hidden)
        # Per transformer layer: ~4*h^2 for the attention projections plus
        # ~3*h*intermediate for a gated MLP; embeddings add vocab*h.
        per_layer = 4 * hidden * hidden + 3 * hidden * intermediate
        total_params = vocab * hidden + layers * per_layer
        # Rough estimation: params * bytes per param + overhead
        if bits == 4:
            bytes_per_param = 0.5  # 4-bit = 0.5 bytes
        elif bits == 8:
            bytes_per_param = 1
        else:
            bytes_per_param = 2  # For 16-bit
        base_memory = total_params * bytes_per_param
        # Add 20% overhead for activations and other components
        total_memory = base_memory * 1.2
        # Convert to GB
        return total_memory / (1024 ** 3)
    except Exception:
        # Fallback estimates when the config is unavailable (offline, bad id,
        # or unexpected architecture). Narrowed from a bare `except:` so
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        if is_system2:
            return 6.0 if bits == 4 else 14.0  # 7B model
        else:
            return 1.2 if bits == 8 else 2.5  # 1.1B model
def get_device_map(model_id, bits=4, is_system2=False):
    """Pick a device placement strategy given the locally available resources.

    Returns "cpu" (no GPU, or 8-bit model that doesn't fit), "auto"
    (model fits on the GPU with headroom), or an explicit per-module map
    that offloads a few modules to CPU for oversized 4-bit models.
    """
    has_gpu = torch.cuda.is_available()
    ram_gb = psutil.virtual_memory().total / (1024 ** 3)  # GB
    needed_gb = estimate_model_memory(model_id, bits, is_system2)

    # Record the resource situation for post-hoc debugging.
    logging.info(f"System memory: {ram_gb:.2f}GB, Estimated model size: {needed_gb:.2f}GB")

    if not has_gpu:
        logging.warning("No GPU detected - using CPU only")
        return "cpu"

    vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)  # GB
    # Fits with a 20% safety buffer -> let accelerate place everything.
    if needed_gb < vram_gb * 0.8:
        return "auto"

    logging.warning(f"Model size ({needed_gb:.2f}GB) exceeds GPU capacity ({vram_gb:.2f}GB). Using CPU offloading.")
    if bits != 4:
        # 8-bit (or wider) models: offload aggressively — run entirely on CPU.
        return "cpu"
    # 4-bit models: keep the bulk on GPU and push a handful of modules to CPU.
    return {
        "": 0,                     # default everything to GPU 0
        "model.layers.0": "cpu",   # offload first layer
        "model.layers.1": "cpu",   # offload second layer
        "model.norm": "cpu",       # offload final normalization
        "lm_head": "cpu",          # offload the LM head
    }
# Model loading with quantization
def _prepare_tokenizer(model_id):
    """Load a tokenizer and guarantee it has a pad token."""
    tok = AutoTokenizer.from_pretrained(model_id)
    if not tok.pad_token:
        tok.pad_token = tok.eos_token
    return tok

def load_quantized_model(model_id, is_system2=False):
    """Load 4-bit quantized model with proper memory management.

    System-2 models use 4-bit NF4 double quantization; System-1 models use
    8-bit. On a GPU out-of-memory failure the model is retried on CPU
    without quantization; other failures are re-raised.
    """
    try:
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            compute_dtype = torch.bfloat16
        else:
            compute_dtype = torch.float16
        device_map = get_device_map(model_id, bits=4 if is_system2 else 8, is_system2=is_system2)

        # Build the quantization config that matches the routing tier.
        if is_system2:
            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=True,
                llm_int8_enable_fp32_cpu_offload=True,  # Enable CPU offloading
            )
        else:
            quant_config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_enable_fp32_cpu_offload=True,  # Enable CPU offloading
                llm_int8_threshold=6.0,
            )

        tokenizer = _prepare_tokenizer(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=quant_config,
            device_map=device_map,
            offload_folder="offload_folder",
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )
        logging.info(f"Successfully loaded {model_id} with device_map: {device_map}")
        return tokenizer, model
    except Exception as e:
        logging.error(f"Model load failed for {model_id}: {str(e)}")
        message = str(e).lower()
        # Only OOM failures get the CPU retry; anything else propagates.
        if "out of memory" not in message and "oom" not in message:
            raise
        logging.warning("GPU memory insufficient. Falling back to CPU loading.")
        try:
            tokenizer = _prepare_tokenizer(model_id)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="cpu",
                trust_remote_code=True,
            )
            logging.info(f"Fallback CPU loading succeeded for {model_id}")
            return tokenizer, model
        except Exception as cpu_e:
            logging.error(f"CPU fallback also failed: {str(cpu_e)}")
            raise
# --- Startup: load both models eagerly so the first request pays no cold-start cost ---
print("Loading quantized models with memory optimization...")
tokenizer1, model1 = load_quantized_model(SYSTEM1_MODEL)
tokenizer2, model2 = load_quantized_model(SYSTEM2_MODEL, is_system2=True)
print("Models loaded successfully!")

# Pipeline generators with memory-efficient settings.
# NOTE(review): device_map="auto" is passed to pipeline() even though each
# model was already placed by load_quantized_model; newer transformers
# versions may warn about or ignore this — confirm against the installed
# transformers version.
system1_pipe = pipeline(
    "text-generation",
    model=model1,
    tokenizer=tokenizer1,
    max_new_tokens=128,
    do_sample=True,
    temperature=0.7,
    pad_token_id=tokenizer1.eos_token_id,
    device_map="auto"
)
system2_pipe = pipeline(
    "text-generation",
    model=model2,
    tokenizer=tokenizer2,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.8,
    pad_token_id=tokenizer2.eos_token_id,
    device_map="auto",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
# Router components
COMPLEX_KEYWORDS = {
    'explain', 'why', 'how', 'compare', 'analyze', 'reason',
    'steps', 'detailed', 'difference', 'advantage', 'disadvantage',
    'calculate', 'derive', 'formula', 'math', 'equation'
}

def is_semantically_complex(query: str) -> bool:
    """Rule-based semantic complexity check.

    Returns True when the query contains a reasoning keyword (matched as a
    whole word), exceeds the token-length threshold, or contains a
    comparison phrase — such queries are routed to System 2.
    """
    lower_query = query.lower()
    # Match keywords against whole words. The previous substring test
    # misfired on words that merely contain a keyword (e.g. "show" -> "how",
    # or "vs" hidden inside a longer word).
    words = set(re.findall(r"[a-z]+", lower_query))
    if words & COMPLEX_KEYWORDS:
        return True
    # Check length threshold (tokens)
    if len(lower_query.split()) > 15:  # Lowered threshold for better routing
        return True
    # Comparison markers: 'vs'/'versus' as whole words, phrases as substrings.
    if 'vs' in words or 'versus' in words:
        return True
    if any(pattern in lower_query for pattern in ['pros and cons', 'advantages and disadvantages', 'compare and contrast']):
        return True
    return False
def calculate_entropy(response):
    """Calculate average token entropy from generation scores.

    Accepts either a dict or a list-of-dicts pipeline output. Returns 0
    when no usable scores are present or on any failure.
    """
    try:
        # Normalize the two pipeline output shapes to a list of score items.
        if isinstance(response, dict):
            scores = response.get("scores", [])
        elif isinstance(response, list) and response:
            scores = response[0].get("scores", [])
        else:
            scores = []

        entropies = []
        for item in scores:
            # Score items arrive in several formats; reduce each to raw logits.
            if isinstance(item, tuple):
                logits = item[0]
            elif hasattr(item, 'logits'):
                logits = item.logits
            else:
                logits = item
            if not isinstance(logits, torch.Tensor):
                continue
            probs = torch.softmax(logits, dim=-1)
            # Shannon entropy; the epsilon guards against log(0).
            entropies.append(-torch.sum(probs * torch.log(probs + 1e-10)).item())

        if not entropies:
            return 0
        return sum(entropies) / len(entropies)
    except Exception as e:
        logging.warning(f"Entropy calculation failed: {str(e)}")
        return 0  # Fallback if scores unavailable
# Request/Response models
class QueryRequest(BaseModel):
    """Incoming payload for the /query endpoint."""
    # Raw user query; leading/trailing whitespace is stripped by the handler.
    text: str
class RouterResponse(BaseModel):
    """Outgoing payload: generated text plus routing diagnostics."""
    response: str
    model_used: str      # "system1", "system2", or "error_fallback"
    routing_reason: str  # e.g. "semantic_complexity", "simple_query", "error_recovery"
    latency_ms: float    # end-to-end handler latency in milliseconds
    # Optional[...] fixes the original `entropy: float = None`, which paired a
    # non-optional annotation with a None default.
    entropy: Optional[float] = None
# Core routing logic
@app.post("/query", response_model=RouterResponse)
async def route_query(request: QueryRequest):
    """Route a query to System 1 (small/fast) or System 2 (large/slow).

    Routing is rule-based (see is_semantically_complex). The generation's
    token entropy is recorded for tracing when the pipeline exposes scores
    (it stays 0.0 otherwise). Any processing error returns a canned apology
    response instead of surfacing a 500.
    """
    start_time = time.time()
    query = request.text.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Empty query")

    routing_reason = ""
    entropy = 0.0
    used_model = ""
    result_text = ""
    try:
        # Step 1: Semantic complexity check — complex queries go straight
        # to the large model.
        if is_semantically_complex(query):
            routing_reason = "semantic_complexity"
            response = system2_pipe(query, max_new_tokens=150)
            used_model = "system2"
        # Step 2: Simple query path via the small model.
        else:
            # Format for TinyLlama chat template
            formatted_query = f"<|system|>\nYou are a helpful assistant.</s>\n<|user|>\n{query}</s>\n<|assistant|>\n"
            response = system1_pipe(
                formatted_query,
                max_new_tokens=100,
                return_full_text=False,
                pad_token_id=tokenizer1.eos_token_id
            )
            # Extract text response
            if isinstance(response, list) and len(response) > 0:
                generated_text = response[0]['generated_text'].strip()
                # Remove the prompt part if it's included
                result_text = generated_text.replace(formatted_query, "").strip()
            else:
                result_text = str(response).strip()
            used_model = "system1"
            routing_reason = "simple_query"

        # Fix: `entropy` was declared but calculate_entropy was never called,
        # so traces always logged 0. The helper safely returns 0 when the
        # pipeline output carries no scores (the default).
        entropy = calculate_entropy(response)

        # Extract text for the system2 path (system1 was handled inline above).
        if not result_text and used_model == "system2":
            if isinstance(response, list) and len(response) > 0:
                result_text = response[0]['generated_text'].replace(query, "", 1).strip()

        # Tracing. Timezone-aware timestamp replaces the deprecated
        # datetime.utcnow() (ISO string now carries an explicit +00:00).
        trace_data = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "query": query,
            "model_used": used_model,
            "routing_reason": routing_reason,
            "entropy": entropy,
            "response": result_text
        }
        logging.info(json.dumps(trace_data))

        # Calculate latency
        latency_ms = (time.time() - start_time) * 1000
        return RouterResponse(
            response=result_text,
            model_used=used_model,
            routing_reason=routing_reason,
            latency_ms=round(latency_ms, 2),
            entropy=round(entropy, 2)
        )
    except Exception as e:
        error_msg = f"Processing error for query '{query[:20]}...': {str(e)}"
        logging.error(error_msg)
        # Fallback to simple response rather than returning a 500.
        return RouterResponse(
            response="I apologize, but I encountered an error processing your request. Please try again with a simpler query.",
            model_used="error_fallback",
            routing_reason="error_recovery",
            latency_ms=round((time.time() - start_time) * 1000, 2),
            entropy=0.0
        )
# Health check endpoint
@app.get("/health")
async def health_check():
    """Report service status, configured models, and available compute."""
    cuda = torch.cuda.is_available()
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3) if cuda else 0
    return {
        "status": "healthy",
        "system1": SYSTEM1_MODEL,
        "system2": SYSTEM2_MODEL,
        "device": "cuda" if cuda else "cpu",
        "gpu_memory_gb": round(gpu_memory, 2),
        "cpu_memory_gb": round(psutil.virtual_memory().total / (1024 ** 3), 2),
    }
# Warmup endpoint to prepare models
@app.post("/warmup")
async def warmup_models():
    """Run one tiny generation through each pipeline to prime kernels/caches."""
    try:
        # Exercise both tiers with short prompts and minimal token budgets.
        for pipe, prompt in (
            (system1_pipe, "Hello, how are you?"),
            (system2_pipe, "What is the capital of France?"),
        ):
            pipe(prompt, max_new_tokens=10)
    except Exception as e:
        return {"status": "warmup failed", "error": str(e)}
    return {"status": "models warmed up successfully"}
if __name__ == "__main__":
    import uvicorn
    print("Starting server with memory-optimized configuration...")
    # Bind on all interfaces so the router is reachable from other hosts.
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")