| |
|
import json
import logging
import math
import os
import time
from datetime import datetime
from typing import Optional

import psutil
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
|
| |
|
| |
|
# Configure file-based tracing: routing decisions are appended to
# router_tracing.log as JSON lines (see route_query) alongside ordinary
# log records from model loading.
logging.basicConfig(
    filename='router_tracing.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

app = FastAPI(title="System1/System2 Router", version="1.0")

# "System 1" (small/fast) and "System 2" (large/deliberate) model ids;
# loaded below with 8-bit and 4-bit quantization respectively.
SYSTEM1_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
SYSTEM2_MODEL = "HuggingFaceH4/zephyr-7b-beta"
|
| |
|
def estimate_model_memory(model_id, bits=4, is_system2=False):
    """Estimate the memory footprint (in GiB) of a quantized model.

    Args:
        model_id: Hugging Face model identifier.
        bits: Quantization width (4 or 8; anything else assumes 16-bit).
        is_system2: Whether this is the large "System 2" model; only used
            to pick a sensible hard-coded fallback when the model config
            cannot be inspected.

    Returns:
        Estimated total memory in GiB (weights plus ~20% overhead), or a
        hard-coded per-tier fallback on any failure.
    """
    try:
        # BUGFIX: the previous version called .numel() on the int values of
        # config.to_dict(), which always raised and fell into the bare
        # except — the estimation path was dead code. Estimate parameters
        # from the architecture hyperparameters instead.
        cfg = AutoConfig.from_pretrained(model_id).to_dict()
        hidden = cfg.get("hidden_size")
        layers = cfg.get("num_hidden_layers")
        vocab = cfg.get("vocab_size")
        if not all(isinstance(v, int) and v > 0 for v in (hidden, layers, vocab)):
            raise ValueError(f"Incomplete config for {model_id}")

        # Standard decoder-only transformer estimate: ~12*h^2 parameters per
        # layer (attention + MLP) plus embedding/unembedding tables.
        total_params = 12 * layers * hidden ** 2 + 2 * vocab * hidden

        if bits == 4:
            bytes_per_param = 0.5
        elif bits == 8:
            bytes_per_param = 1
        else:
            bytes_per_param = 2  # assume fp16/bf16

        base_memory = total_params * bytes_per_param
        # ~20% headroom for activations, KV cache and fragmentation.
        total_memory = base_memory * 1.2

        return total_memory / (1024 ** 3)
    except Exception as e:
        # Config lookup can fail offline or for gated repos; fall back to
        # known-good estimates for the two models this service ships with.
        logging.warning(f"Memory estimation failed for {model_id}: {e}")
        if is_system2:
            return 6.0 if bits == 4 else 14.0
        else:
            return 1.2 if bits == 8 else 2.5
|
| |
|
def get_device_map(model_id, bits=4, is_system2=False):
    """Create an appropriate device map based on available resources.

    Returns "cpu" (no GPU or non-4-bit overflow), "auto" (model fits on
    the GPU with 20% headroom), or an explicit offloading dict for 4-bit
    models that exceed VRAM.
    """
    has_gpu = torch.cuda.is_available()
    system_ram_gb = psutil.virtual_memory().total / (1024 ** 3)
    model_size_gb = estimate_model_memory(model_id, bits, is_system2)

    logging.info(f"System memory: {system_ram_gb:.2f}GB, Estimated model size: {model_size_gb:.2f}GB")

    # Guard clause: no GPU at all means everything stays on the CPU.
    if not has_gpu:
        logging.warning("No GPU detected - using CPU only")
        return "cpu"

    vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)

    # Model comfortably fits on the GPU: let accelerate place it.
    if model_size_gb < vram_gb * 0.8:
        return "auto"

    logging.warning(f"Model size ({model_size_gb:.2f}GB) exceeds GPU capacity ({vram_gb:.2f}GB). Using CPU offloading.")

    # Non-4-bit models that overflow VRAM go entirely to CPU.
    if bits != 4:
        return "cpu"

    # 4-bit: keep the bulk on GPU 0 and offload a few named modules.
    return {
        "": 0,
        "model.layers.0": "cpu",
        "model.layers.1": "cpu",
        "model.norm": "cpu",
        "lm_head": "cpu",
    }
|
| |
|
| |
|
def load_quantized_model(model_id, is_system2=False):
    """Load a quantized causal LM with memory-aware placement.

    System 2 models are loaded in 4-bit (NF4, double quantization);
    System 1 models in 8-bit. On an out-of-memory error the load is
    retried once, unquantized, on CPU.

    Args:
        model_id: Hugging Face model identifier.
        is_system2: Selects the 4-bit (True) vs 8-bit (False) config.

    Returns:
        (tokenizer, model) tuple.

    Raises:
        Exception: the original load error when it is not an OOM, or when
            the CPU fallback also fails.
    """
    try:
        # Prefer bf16 compute where the GPU supports it; fp16 otherwise.
        compute_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
        device_map = get_device_map(model_id, bits=4 if is_system2 else 8, is_system2=is_system2)

        if is_system2:
            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=True,
                llm_int8_enable_fp32_cpu_offload=True
            )
        else:
            quant_config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_enable_fp32_cpu_offload=True,
                llm_int8_threshold=6.0
            )

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Some chat models ship without a pad token; reuse EOS for padding.
        if not tokenizer.pad_token:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=quant_config,
            device_map=device_map,
            offload_folder="offload_folder",
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )

        logging.info(f"Successfully loaded {model_id} with device_map: {device_map}")
        return tokenizer, model

    except Exception as e:
        logging.error(f"Model load failed for {model_id}: {str(e)}")

        # Only OOM-looking errors get the CPU fallback; anything else propagates.
        if "out of memory" in str(e).lower() or "oom" in str(e).lower():
            logging.warning("GPU memory insufficient. Falling back to CPU loading.")
            try:
                tokenizer = AutoTokenizer.from_pretrained(model_id)
                if not tokenizer.pad_token:
                    tokenizer.pad_token = tokenizer.eos_token

                # NOTE(review): this fallback loads the model unquantized on
                # CPU, needing roughly 2-4x the quantized footprint in RAM —
                # confirm the host has headroom for the 7B model.
                model = AutoModelForCausalLM.from_pretrained(
                    model_id,
                    device_map="cpu",
                    trust_remote_code=True
                )
                logging.info(f"Fallback CPU loading succeeded for {model_id}")
                return tokenizer, model
            except Exception as cpu_e:
                logging.error(f"CPU fallback also failed: {str(cpu_e)}")
                raise
        raise
|
| |
|
| |
|
# Load both models at import time so the first request doesn't pay the cost.
print("Loading quantized models with memory optimization...")
tokenizer1, model1 = load_quantized_model(SYSTEM1_MODEL)
tokenizer2, model2 = load_quantized_model(SYSTEM2_MODEL, is_system2=True)
print("Models loaded successfully!")

# System 1: small/fast pipeline for simple queries.
# NOTE(review): passing device_map="auto" to pipeline() for a model that was
# already placed by from_pretrained may be redundant or emit a warning —
# confirm against the installed transformers version.
system1_pipe = pipeline(
    "text-generation",
    model=model1,
    tokenizer=tokenizer1,
    max_new_tokens=128,
    do_sample=True,
    temperature=0.7,
    pad_token_id=tokenizer1.eos_token_id,
    device_map="auto"
)

# System 2: larger pipeline for complex queries; longer, hotter generations.
system2_pipe = pipeline(
    "text-generation",
    model=model2,
    tokenizer=tokenizer2,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.8,
    pad_token_id=tokenizer2.eos_token_id,
    device_map="auto",
    # NOTE(review): torch_dtype here likely has no effect on an already
    # 4-bit-quantized model — confirm before relying on it.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
|
| |
|
| |
|
# Single-word cues that a query needs deliberate ("System 2") reasoning.
# Matched as whole tokens, not substrings.
COMPLEX_KEYWORDS = {
    'explain', 'why', 'how', 'compare', 'analyze', 'reason',
    'steps', 'detailed', 'difference', 'advantage', 'disadvantage',
    'calculate', 'derive', 'formula', 'math', 'equation'
}

# Multi-word comparison phrases, safe to check as substrings of the query.
_COMPLEX_PHRASES = ('versus', 'pros and cons', 'advantages and disadvantages', 'compare and contrast')


def is_semantically_complex(query: str) -> bool:
    """Rule-based semantic complexity check.

    A query is considered complex (routed to System 2) when:
      * it contains a reasoning keyword as a whole word, or
      * it is long (more than 15 whitespace tokens), or
      * it contains a comparison phrase ('vs', 'versus', 'pros and cons', ...).

    Args:
        query: Raw user query text.

    Returns:
        True if the query should be routed to the larger System 2 model.
    """
    lower_query = query.lower()
    # Strip surrounding punctuation so e.g. "explain," still matches "explain".
    tokens = [t.strip('.,;:!?()"\'') for t in lower_query.split()]

    # BUGFIX: the previous substring check misfired on unrelated words
    # ("how" matched inside "shower"); match keywords as whole tokens.
    if COMPLEX_KEYWORDS.intersection(tokens):
        return True

    if len(tokens) > 15:
        return True

    # 'vs' must be a standalone token; longer phrases stay substring checks.
    if 'vs' in tokens or any(pattern in lower_query for pattern in _COMPLEX_PHRASES):
        return True

    return False
|
| |
|
def calculate_entropy(response):
    """Calculate average per-step token entropy from generation scores.

    Args:
        response: Either a dict with a "scores" key, or a non-empty list
            whose first element is such a dict. Each scores entry may be a
            raw logits tensor, a (logits, ...) tuple, or an object with a
            `.logits` attribute.

    Returns:
        Mean entropy in nats across generation steps, or 0 when no usable
        scores are present or on any failure (logged, never raised).
    """
    try:
        # Normalize the two accepted container shapes to a scores list.
        if isinstance(response, dict):
            scores = response.get("scores", [])
        elif isinstance(response, list) and len(response) > 0:
            scores = response[0].get("scores", [])
        else:
            scores = []

        entropies = []
        for item in scores:
            # Unwrap the various per-step shapes down to raw logits.
            if isinstance(item, tuple):
                logits = item[0]
            elif hasattr(item, 'logits'):
                logits = item.logits
            else:
                logits = item

            if isinstance(logits, torch.Tensor):
                # log_softmax is numerically stable; the previous
                # softmax + log(p + 1e-10) slightly biased the entropy.
                log_probs = torch.log_softmax(logits.float(), dim=-1)
                entropies.append(-(log_probs.exp() * log_probs).sum().item())

        return sum(entropies) / len(entropies) if entropies else 0
    except Exception as e:
        logging.warning(f"Entropy calculation failed: {str(e)}")
        return 0
|
| |
|
| |
|
class QueryRequest(BaseModel):
    """Request body for POST /query."""
    # Raw user query text; rejected with 400 if empty after stripping.
    text: str
|
| |
|
class RouterResponse(BaseModel):
    """Response body for POST /query."""
    # Generated answer text (or an apology message on error fallback).
    response: str
    # Which backend produced it: "system1", "system2", or "error_fallback".
    model_used: str
    # Why the router chose that backend, e.g. "semantic_complexity".
    routing_reason: str
    # End-to-end handler latency in milliseconds.
    latency_ms: float
    # Mean token entropy of the generation, when computed. BUGFIX: the
    # previous `float = None` annotation is invalid (non-optional field
    # with a None default; pydantic v2 rejects it).
    entropy: Optional[float] = None
|
| |
|
| |
|
@app.post("/query", response_model=RouterResponse)
async def route_query(request: QueryRequest):
    """Route a query to System 1 (fast) or System 2 (deliberate) and answer it.

    Routing is purely rule-based via is_semantically_complex(); every
    request/response pair is appended to the trace log as a JSON line.
    On any processing error a graceful fallback response is returned
    instead of a 500.
    """
    start_time = time.time()
    query = request.text.strip()

    # Reject blank queries up front.
    if not query:
        raise HTTPException(status_code=400, detail="Empty query")

    routing_reason = ""
    # NOTE(review): entropy is never updated below — calculate_entropy() is
    # defined but unused here, so the reported entropy is always 0.0.
    entropy = 0.0
    used_model = ""
    result_text = ""

    try:
        if is_semantically_complex(query):
            routing_reason = "semantic_complexity"
            # System 2: larger model, raw query (no chat template applied).
            response = system2_pipe(query, max_new_tokens=150)
            used_model = "system2"

        else:
            # System 1 uses the TinyLlama/Zephyr-style chat template explicitly.
            formatted_query = f"<|system|>\nYou are a helpful assistant.</s>\n<|user|>\n{query}</s>\n<|assistant|>\n"

            response = system1_pipe(
                formatted_query,
                max_new_tokens=100,
                return_full_text=False,
                pad_token_id=tokenizer1.eos_token_id
            )

            if isinstance(response, list) and len(response) > 0:
                generated_text = response[0]['generated_text'].strip()
                # Defensive: strip the prompt in case the pipeline echoed it
                # despite return_full_text=False.
                result_text = generated_text.replace(formatted_query, "").strip()
            else:
                result_text = str(response).strip()

            used_model = "system1"
            routing_reason = "simple_query"

        # System 2 output is extracted here (its branch above only generates):
        # remove the first occurrence of the prompt from the full text.
        if not result_text and used_model == "system2":
            if isinstance(response, list) and len(response) > 0:
                result_text = response[0]['generated_text'].replace(query, "", 1).strip()

        # Append a structured trace line for offline analysis.
        trace_data = {
            "timestamp": datetime.utcnow().isoformat(),
            "query": query,
            "model_used": used_model,
            "routing_reason": routing_reason,
            "entropy": entropy,
            "response": result_text
        }
        logging.info(json.dumps(trace_data))

        latency_ms = (time.time() - start_time) * 1000

        return RouterResponse(
            response=result_text,
            model_used=used_model,
            routing_reason=routing_reason,
            latency_ms=round(latency_ms, 2),
            entropy=round(entropy, 2)
        )

    except Exception as e:
        # Catch-all boundary: log and degrade gracefully rather than 500.
        error_msg = f"Processing error for query '{query[:20]}...': {str(e)}"
        logging.error(error_msg)

        return RouterResponse(
            response="I apologize, but I encountered an error processing your request. Please try again with a simpler query.",
            model_used="error_fallback",
            routing_reason="error_recovery",
            latency_ms=round((time.time() - start_time) * 1000, 2),
            entropy=0.0
        )
|
| |
|
| |
|
@app.get("/health")
async def health_check():
    """Report service status plus basic hardware capacity figures."""
    cuda_ok = torch.cuda.is_available()

    gpu_memory = 0
    if cuda_ok:
        props = torch.cuda.get_device_properties(0)
        gpu_memory = props.total_memory / (1024 ** 3)

    ram_gb = psutil.virtual_memory().total / (1024 ** 3)

    return {
        "status": "healthy",
        "system1": SYSTEM1_MODEL,
        "system2": SYSTEM2_MODEL,
        "device": "cuda" if cuda_ok else "cpu",
        "gpu_memory_gb": round(gpu_memory, 2),
        "cpu_memory_gb": round(ram_gb, 2),
    }
|
| |
|
| |
|
@app.post("/warmup")
async def warmup_models():
    """Run one tiny generation through each pipeline to pay first-call costs."""
    try:
        # Warm System 1 first, then System 2, matching startup order.
        for pipe, prompt in (
            (system1_pipe, "Hello, how are you?"),
            (system2_pipe, "What is the capital of France?"),
        ):
            pipe(prompt, max_new_tokens=10)

        return {"status": "models warmed up successfully"}
    except Exception as e:
        return {"status": "warmup failed", "error": str(e)}
|
| |
|
if __name__ == "__main__":
    # Local import: uvicorn is only needed when running as a script.
    import uvicorn
    print("Starting server with memory-optimized configuration...")
    # Binds on all interfaces with no TLS/auth — front with a proxy in production.
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")