"""FastAPI service exposing OpenGVLab/InternVL2_5-2B image-to-text generation.

The model is loaded lazily on the first request (thread-safe) to keep
process startup fast.
"""
import os
import io
import threading
from typing import Optional

import numpy as np
import torch
from fastapi import FastAPI, UploadFile, File, Depends
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
from PIL import Image

app = FastAPI()

model_id = "OpenGVLab/InternVL2_5-2B"
device = "cpu"

# ImageNet normalization constants used by InternVL's vision encoder.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Globals for the lazily-loaded model. Guarded by _load_lock so that two
# concurrent first requests cannot both pass the `model is None` check and
# load the ~5GB weights twice.
model = None
tokenizer = None
_load_lock = threading.Lock()


def load_model_if_needed():
    """Load the tokenizer and model on first use (idempotent, thread-safe)."""
    global model, tokenizer
    with _load_lock:
        if model is not None:
            return
        print("Loading model... this may take a moment.")
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        # bfloat16 roughly halves memory (5GB vs 10GB); low_cpu_mem_usage
        # avoids holding a second full copy of the weights while loading.
        model = AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
        ).to(device).eval()
        print("Model loaded successfully!")


def preprocess_image(pil_image: Image.Image, input_size: int = 448) -> torch.Tensor:
    """Convert a PIL RGB image into the (1, 3, input_size, input_size)
    ImageNet-normalized bfloat16 tensor InternVL's vision encoder expects
    (single 448x448 tile)."""
    resized = pil_image.resize((input_size, input_size))
    arr = np.asarray(resized, dtype=np.float32) / 255.0
    tensor = torch.from_numpy(arr).permute(2, 0, 1)  # HWC -> CHW
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    tensor = (tensor - mean) / std
    return tensor.unsqueeze(0).to(torch.bfloat16)


class GenerateRequest(BaseModel):
    # Optional question/instruction; defaults to a generic description prompt.
    text_input: Optional[str] = None
    max_new_tokens: int = 1024
    do_sample: bool = False


@app.post("/generate")
async def generate(image: UploadFile = File(...), request: GenerateRequest = Depends()):
    """Run one round of visual question answering on the uploaded image."""
    # Trigger the (lazy) load on the first request.
    load_model_if_needed()

    # Read and decode the uploaded image.
    image_bytes = await image.read()
    pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    pixel_values = preprocess_image(pil_image).to(device)

    # InternVL requires an "<image>" placeholder in the question so the
    # image tokens are spliced into the right position.
    question_text = request.text_input or "Describe the image."
    question = f"<image>\n{question_text}"

    generation_config = {
        "max_new_tokens": request.max_new_tokens,
        "do_sample": request.do_sample,
    }

    # Use the model-card API: model.chat() handles prompt templating,
    # image-token insertion, generation, and decoding. (A plain tokenizer
    # call cannot consume a PIL image, which is why the previous
    # tokenizer(prompt, pil_image) approach failed.)
    with torch.no_grad():
        response = model.chat(tokenizer, pixel_values, question, generation_config)

    return {"generated_text": response}


@app.get("/")
async def read_root():
    return {"message": "InternVL2_5-2B API. Go to /docs for API documentation."}