import logging
import os

import modal
from fastapi import Header

from models import MODEL_IDS

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

CACHE_DIR = "/cache"

image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install("torch", "transformers", "accelerate", "fastapi", "bitsandbytes")
    .add_local_dir("site", "/root")
)

app = modal.App("posttraining-chat", image=image)
cache_vol = modal.Volume.from_name("hf-cache", create_if_missing=True)


@app.cls(
    gpu="T4",
    scaledown_window=60,
    secrets=[modal.Secret.from_dotenv()],
    volumes={CACHE_DIR: cache_vol},
)
class Inference:
    @modal.enter()
    def setup(self):
        # Point the Hugging Face cache at the shared volume so downloaded
        # weights survive container restarts.
        os.environ["HF_HOME"] = CACHE_DIR
        self.models = {}

    def load_model(self, model_id: str):
        if model_id in self.models:
            logger.info(f"Model already loaded: {model_id}")
            return

        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        logger.info(f"Loading model: {model_id}")
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            logger.info(f"Tokenizer loaded for {model_id}")
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            logger.info(f"Model loaded successfully: {model_id}")
            self.models[model_id] = {"model": model, "tokenizer": tokenizer}
            # Persist any freshly downloaded weights to the volume.
            cache_vol.commit()
        except Exception as e:
            logger.error(f"Failed to load model {model_id}: {e}")
            raise

    @modal.fastapi_endpoint(method="POST")
    def generate(self, request: dict, x_api_key: str | None = Header(None)) -> dict:
        import torch

        logger.info(
            f"Received request: model_id={request.get('model_id')}, "
            f"message_len={len(request.get('message', ''))}, "
            f"history_len={len(request.get('history', []))}, "
            f"message_preview={request.get('message', '')[:100]!r}"
        )

        expected_key = os.environ.get("MODEL_SITE_API_KEY")
        if not expected_key or x_api_key != expected_key:
            logger.warning("Auth failed: invalid or missing API key")
            return {"error": "Unauthorized - invalid API key"}

        model_id = request.get("model_id", MODEL_IDS[0])
        message = request.get("message", "")
        history = request.get("history", [])

        if model_id not in MODEL_IDS:
            logger.warning(f"Model not found: {model_id}")
            return {"error": f"Model {model_id} not found"}

        try:
            self.load_model(model_id)
        except Exception as e:
            logger.error(f"Model loading failed: {e}")
            return {"error": f"Failed to load model: {e}"}

        tokenizer = self.models[model_id]["tokenizer"]
        model = self.models[model_id]["model"]

        # Rebuild the conversation in chat-template format: prior turns first,
        # then the new user message.
        messages = []
        for msg in history:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            messages.append({"role": role, "content": content})
        messages.append({"role": "user", "content": message})

        conversation = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        try:
            inputs = tokenizer(conversation, return_tensors="pt").to("cuda")
            logger.info(f"Tokenized input shape: {inputs['input_ids'].shape}")

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=1024,
                    do_sample=True,
                    temperature=0.4,
                    top_p=0.85,
                    repetition_penalty=1.15,
                    pad_token_id=tokenizer.eos_token_id,
                )
            logger.info(f"Generated output shape: {outputs.shape}")

            # Extract only the newly generated tokens (skip the input)
            new_tokens = outputs[0][inputs["input_ids"].shape[1] :]
            response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
            logger.info(f"Final response length: {len(response)}")
            logger.info(f"Response: {response}")

            return {"response": response}
        except Exception as e:
            logger.error(f"Inference failed: {e}", exc_info=True)
            return {"error": f"Inference failed: {e}"}
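

# --- Usage sketch (illustrative, not part of the app) ------------------------
# A minimal client-side sketch of calling the endpoint above, assuming the app
# has been deployed with `modal deploy`. ENDPOINT_URL is a placeholder: Modal
# prints the real URL at deploy time. MODEL_SITE_API_KEY is assumed to be set
# in the client's environment to the same value the server loads from .env.
# The __main__ guard keeps this from running when Modal imports the module.
if __name__ == "__main__":
    import requests  # client-side dependency; not part of the Modal image

    # Placeholder URL; substitute the one `modal deploy` reports.
    ENDPOINT_URL = "https://<workspace>--posttraining-chat-inference-generate.modal.run"

    payload = {
        "model_id": MODEL_IDS[0],  # must be one of the registered MODEL_IDS
        "message": "Summarize what post-training changed about you.",
        "history": [],  # prior turns as {"role": ..., "content": ...} dicts
    }
    resp = requests.post(
        ENDPOINT_URL,
        json=payload,
        headers={"x-api-key": os.environ["MODEL_SITE_API_KEY"]},
        timeout=300,
    )
    # Expect {"response": "..."} on success or {"error": "..."} on failure.
    print(resp.json())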