import asyncio
import base64
import concurrent.futures
import io
import logging
import os
from pathlib import Path
from typing import List

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from omegaconf import DictConfig, OmegaConf
from pydantic import BaseModel

from inference import InferenceRecipe
| |
|
# Module-wide logger used by every handler below.
logger = logging.getLogger(__name__)
# Install a root handler at INFO so the logger.info(...) calls are emitted
# (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
| |
|
# ASGI application; the route and startup handlers below register on it.
app = FastAPI()

# Fully permissive CORS: any origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True lets
# arbitrary sites make credentialed cross-origin calls; check the Starlette
# CORSMiddleware docs and restrict allow_origins if this service ever relies
# on cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
| |
|
class EmbeddingRequest(BaseModel):
    """Request body for ``POST /api/v1/inference``.

    ``embedding`` is a flat list of floats; the inference endpoint reshapes it
    into rows of 1024 values before generation.
    """

    embedding: List[float]
| |
|
class TextResponse(BaseModel):
    """Response body for ``POST /api/v1/inference``."""

    # Generated texts. pydantic copies mutable defaults per instance, so the
    # list literal is not shared between responses.
    texts: List[str] = []
| |
|
| | |
# Readiness state shared by initialize_model() and the HTTP handlers:
#   model_loaded -- flipped to True once the recipe has been set up;
#   error        -- message of the last failed initialization, else None.
INITIALIZATION_STATUS = dict(model_loaded=False, error=None)

# Inference recipe and its OmegaConf config; None until initialize_model()
# assigns them.
inference_recipe = cfg = None
| |
|
| |
|
def initialize_model(
    model_dir: str = "/app/models",
    config_path: str = "/app/training_config.yml",
) -> bool:
    """Load the training config and set up the global ``InferenceRecipe``.

    The recipe and config are published to the module globals
    ``inference_recipe`` / ``cfg`` only after ``setup()`` has succeeded, so
    other handlers never see a half-initialized recipe.

    Args:
        model_dir: Directory holding the checkpoint and ``tokenizer.model``.
        config_path: Path to the OmegaConf training config YAML.

    Returns:
        True on success. On failure, returns False instead of raising. The
        exception message is stored in ``INITIALIZATION_STATUS["error"]`` and
        the existing globals are left untouched.
    """
    global inference_recipe, cfg
    try:
        # Only logged here; the recipe is presumably responsible for picking
        # its own device from the config (TODO confirm).
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Initializing model on device: %s", device)

        model_path = os.path.abspath(model_dir)
        logger.info("Loading models from: %s", model_path)
        if not os.path.isdir(model_path):
            raise RuntimeError(
                f"Model path {model_path} does not exist or is not a directory"
            )
        logger.info("Available model files: %s", os.listdir(model_path))

        new_cfg = OmegaConf.load(config_path)
        # Override the model section: models.mmllama3_8b with use_clip=False,
        # carrying over perception_tokens from the loaded config.
        new_cfg.model = DictConfig({
            "_component_": "models.mmllama3_8b",
            "use_clip": False,
            "perception_tokens": new_cfg.model.perception_tokens,
        })
        new_cfg.checkpointer.checkpoint_dir = model_path
        new_cfg.checkpointer.checkpoint_files = ["meta_model_5.pt"]
        new_cfg.inference.max_new_tokens = 300
        new_cfg.tokenizer.path = os.path.join(model_path, "tokenizer.model")

        recipe = InferenceRecipe(new_cfg)
        recipe.setup(cfg=new_cfg)
    except Exception as e:
        INITIALIZATION_STATUS["error"] = str(e)
        logger.exception("Failed to initialize model: %s", e)
        return False

    # Publish only after setup() succeeded.
    cfg = new_cfg
    inference_recipe = recipe
    INITIALIZATION_STATUS["model_loaded"] = True
    INITIALIZATION_STATUS["error"] = None
    logger.info("Model initialized successfully")
    return True
| | |
@app.on_event("startup")
async def startup_event():
    """Load the model when the application starts.

    initialize_model() is a plain synchronous call inside this coroutine, so
    it blocks the event loop until it returns. Its boolean result is
    deliberately ignored: failures are recorded in INITIALIZATION_STATUS.
    """
    # NOTE(review): FastAPI documents lifespan handlers as the successor to
    # on_event; consider migrating.
    _ = initialize_model()
| |
|
@app.get("/api/v1/health")
def health_check():
    """Report model readiness.

    Returns a dict with:
      * ``status``: "healthy" once the model is loaded, "error" if
        initialization failed, otherwise "initializing";
      * ``initialization_status``: a snapshot copy of INITIALIZATION_STATUS;
      * ``device`` / ``dtype``: stringified recipe attributes, present only
        when a recipe object exists.
    """
    if INITIALIZATION_STATUS["model_loaded"]:
        state = "healthy"
    elif INITIALIZATION_STATUS["error"] is not None:
        # Initialization already failed; "initializing" would make pollers
        # wait for a model that will never load.
        state = "error"
    else:
        state = "initializing"

    status = {
        "status": state,
        # Copy so the response is a snapshot, not the live shared dict.
        "initialization_status": dict(INITIALIZATION_STATUS),
    }

    if inference_recipe is not None:
        # _device/_dtype are private recipe attributes, presumably set during
        # setup(); getattr keeps the health check from raising if they are
        # missing.
        status["device"] = str(getattr(inference_recipe, "_device", None))
        status["dtype"] = str(getattr(inference_recipe, "_dtype", None))

    return status
| |
|
# Width of each embedding row passed to generate_batch; the flat request
# vector is reshaped to (-1, EMBEDDING_DIM).
EMBEDDING_DIM = 1024

# A single dedicated worker runs generation off the event loop (so health
# checks and other requests are not stalled) while still executing one
# generation at a time, as the previous event-loop-blocking code did.
_GENERATION_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
    max_workers=1, thread_name_prefix="inference"
)


def _generate_texts(embedding):
    """Run generate_batch on *embedding* in the worker thread, without autograd."""
    # Grad mode is thread-local in PyTorch, so disable it explicitly here;
    # text generation never needs gradients.
    with torch.no_grad():
        return inference_recipe.generate_batch(cfg=cfg, video_ib_embed=embedding)


@app.post("/api/v1/inference")
async def inference(request: EmbeddingRequest) -> TextResponse:
    """Generate text from a flat embedding vector.

    Raises:
        HTTPException(503): the model is not loaded yet.
        HTTPException(422): the embedding is empty or its length is not a
            multiple of EMBEDDING_DIM.
        HTTPException(500): generation or response construction failed.
    """
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
        )

    logger.info("Received inference request")

    values = request.embedding
    # Reject malformed input up front instead of letting reshape() fail with
    # an opaque 500.
    if not values or len(values) % EMBEDDING_DIM != 0:
        raise HTTPException(
            status_code=422,
            detail=(
                f"embedding length must be a positive multiple of "
                f"{EMBEDDING_DIM}, got {len(values)}"
            ),
        )

    try:
        embedding = torch.tensor(values).reshape(-1, EMBEDDING_DIM)
        logger.info("Converted embedding to tensor with shape: %s", embedding.shape)

        loop = asyncio.get_running_loop()
        results = await loop.run_in_executor(
            _GENERATION_EXECUTOR, _generate_texts, embedding
        )
        logger.info("Generation complete")

        if isinstance(results, str):
            results = [results]
        return TextResponse(texts=results)
    except Exception as e:
        logger.error("Inference failed: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
| | |
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is run directly.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)
| |
|
| |
|