from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
import os
import logging

import torch
from omegaconf import OmegaConf, DictConfig

from inference import InferenceRecipe

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Allow cross-origin requests from any origin. Convenient for a demo;
# consider restricting allow_origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class EmbeddingRequest(BaseModel):
    embedding: List[float]


class TextResponse(BaseModel):
    texts: List[str] = []

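# Illustrative payload shapes for the models above (the flat embedding list
# must be a multiple of 1024 floats, per the reshape in the inference
# endpoint below):
#   request:  {"embedding": [0.12, -0.34, ...]}
#   response: {"texts": ["generated text 1", "..."]}
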
# Track initialization state so the health endpoint can report readiness
# (and any startup error) instead of failing opaquely.
INITIALIZATION_STATUS = {
    "model_loaded": False,
    "error": None
}

# Populated by initialize_model() at startup.
inference_recipe = None
cfg = None

def initialize_model():
    """Initialize the model with correct path resolution."""
    global inference_recipe, INITIALIZATION_STATUS, cfg
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Initializing model on device: {device}")

        model_path = os.path.abspath(os.path.join('/app', 'models'))
        logger.info(f"Loading models from: {model_path}")

        if not os.path.exists(model_path):
            raise RuntimeError(f"Model path {model_path} does not exist")

        model_files = os.listdir(model_path)
        logger.info(f"Available model files: {model_files}")

        # Load the training config, then override the pieces that differ at
        # inference time: model component, checkpoint location, generation
        # length, and tokenizer path.
        cfg = OmegaConf.load(os.path.join('/app', 'training_config.yml'))
        cfg.model = DictConfig({
            "_component_": "models.mmllama3_8b",
            "use_clip": False,
            "perception_tokens": cfg.model.perception_tokens,
        })
        cfg.checkpointer.checkpoint_dir = model_path
        cfg.checkpointer.checkpoint_files = ["meta_model_5.pt"]
        cfg.inference.max_new_tokens = 300
        cfg.tokenizer.path = os.path.join(model_path, "tokenizer.model")

        inference_recipe = InferenceRecipe(cfg)
        inference_recipe.setup(cfg=cfg)
        INITIALIZATION_STATUS["model_loaded"] = True
        logger.info("Model initialized successfully")
        return True
    except Exception as e:
        INITIALIZATION_STATUS["error"] = str(e)
        logger.error(f"Failed to initialize model: {e}")
        return False

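# For reference, a minimal sketch of the keys initialize_model() reads from
# training_config.yml (illustrative values only; the real file comes from
# training and may contain more):
#
#   model:
#     perception_tokens: 2          # value assumed here for illustration
#   checkpointer:
#     checkpoint_dir: /app/models   # overridden above
#     checkpoint_files: [meta_model_5.pt]
#   inference:
#     max_new_tokens: 300
#   tokenizer:
#     path: /app/models/tokenizer.model
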
@app.on_event("startup")
async def startup_event():
    """Initialize model on startup."""
    initialize_model()

@app.get("/api/v1/health")
def health_check():
    """Health check endpoint."""
    status = {
        "status": "healthy" if INITIALIZATION_STATUS["model_loaded"] else "initializing",
        "initialization_status": INITIALIZATION_STATUS
    }

    if inference_recipe is not None:
        status.update({
            "device": str(inference_recipe._device),
            "dtype": str(inference_recipe._dtype)
        })

    return status

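# Illustrative health probe (assumes the server listens on localhost:8000):
#   curl http://localhost:8000/api/v1/health
# Returns {"status": "healthy", ...} once the model has loaded, plus the
# recipe's device and dtype when available.
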
@app.post("/api/v1/inference")
async def inference(request: EmbeddingRequest) -> TextResponse:
    """Run inference with enhanced error handling and logging."""
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
        )

    # The flat embedding list must reshape cleanly into rows of 1024 floats;
    # reject malformed input with a 400 rather than a generic 500.
    if not request.embedding or len(request.embedding) % 1024 != 0:
        raise HTTPException(
            status_code=400,
            detail=f"Embedding length {len(request.embedding)} is not a non-zero multiple of 1024"
        )

    try:
        logger.info("Received inference request")

        # Convert the flat list into a (batch, 1024) tensor.
        embedding = torch.tensor(request.embedding).reshape(-1, 1024)
        logger.info(f"Converted embedding to tensor with shape: {embedding.shape}")

        results = inference_recipe.generate_batch(cfg=cfg, video_ib_embed=embedding)
        logger.info("Generation complete")

        # generate_batch may return a single string; normalize to a list.
        if isinstance(results, str):
            results = [results]

        return TextResponse(texts=results)

    except Exception as e:
        logger.error(f"Inference failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=str(e)
        )

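# A minimal client sketch (illustrative; assumes a single 1024-dim embedding
# and a server reachable at localhost:8000):
#
#   import requests
#   emb = [0.0] * 1024  # placeholder; use a real video embedding
#   resp = requests.post(
#       "http://localhost:8000/api/v1/inference",
#       json={"embedding": emb},
#   )
#   resp.raise_for_status()
#   print(resp.json()["texts"])
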
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
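
# Alternatively, run via the uvicorn CLI (assuming this file is saved as
# main.py; adjust the module name to match the actual filename):
#   uvicorn main:app --host 0.0.0.0 --port 8000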