import os
import torch
from fastapi import FastAPI, UploadFile, File, Depends
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
from PIL import Image
from typing import Optional
import io
app = FastAPI()

# Hugging Face repo to serve. Overridable via the MODEL_ID env var so the
# same server script can host other trust_remote_code InternVL checkpoints;
# the default matches the previous hard-coded value.
model_id = os.environ.get("MODEL_ID", "OpenGVLab/InternVL2_5-2B")
# Inference device. Defaults to CPU so the server runs anywhere; set
# DEVICE=cuda (or cuda:0, mps, ...) to use an accelerator.
device = os.environ.get("DEVICE", "cpu")

# Model and tokenizer are lazily initialised on the first /generate request
# (see load_model_if_needed) so server startup stays fast.
model = None
tokenizer = None
def load_model_if_needed():
    """Lazily initialise the global tokenizer and model on first use.

    Safe to call on every request: once the model is resident in memory,
    subsequent calls return immediately without reloading.
    """
    global model, tokenizer
    if model is not None:
        return
    print("Loading model... this may take a moment.")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # bfloat16 halves the weight footprint (~5GB vs ~10GB); low_cpu_mem_usage
    # avoids materialising a second full copy of the weights while loading.
    loaded = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    model = loaded.to(device).eval()
    print("Model loaded successfully!")
class GenerateRequest(BaseModel):
    """Query parameters for /generate (bound via Depends() in the endpoint)."""
    # Optional user prompt; when omitted the endpoint falls back to a
    # generic "describe the image" instruction.
    text_input: Optional[str] = None
    # Upper bound on the number of tokens the model may generate.
    max_new_tokens: int = 1024
    # False = greedy decoding (deterministic); True enables sampling.
    do_sample: bool = False
# ImageNet normalisation constants used by InternVL's vision encoder
# (matching the preprocessing in the OpenGVLab/InternVL2_5 model card).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def _preprocess_image(pil_image, input_size=448):
    """Convert an RGB PIL image into a (1, 3, S, S) bfloat16 pixel tensor.

    NOTE(review): this resizes to a single 448x448 tile. The model card's
    reference code additionally tiles large images by aspect ratio
    (dynamic_preprocess); a single tile is valid input, just lower detail.
    """
    resized = pil_image.resize((input_size, input_size))
    # raw bytes -> (H, W, C) uint8 -> (C, H, W) float in [0, 1]
    tensor = torch.frombuffer(bytearray(resized.tobytes()), dtype=torch.uint8)
    tensor = tensor.reshape(input_size, input_size, 3)
    tensor = tensor.permute(2, 0, 1).float().div_(255.0)
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    tensor = (tensor - mean) / std
    # Add the batch dimension and match the model weights' dtype.
    return tensor.unsqueeze(0).to(torch.bfloat16)


@app.post("/generate")
async def generate(image: UploadFile = File(...), request: GenerateRequest = Depends()):
    """Run the vision-language model on an uploaded image.

    Fixes vs. the previous revision:
    - AutoTokenizer cannot consume a PIL image positionally; InternVL models
      are driven through ``model.chat(tokenizer, pixel_values, question,
      generation_config)`` as documented on the model card.
    - The question must contain the literal ``<image>`` placeholder token.
    """
    # 1. Trigger the (lazy) model load on the first request.
    load_model_if_needed()
    # 2. Decode the upload into an RGB image.
    image_bytes = await image.read()
    pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    # 3. Preprocess into the pixel tensor the vision encoder expects.
    pixel_values = _preprocess_image(pil_image).to(device)
    # 4. Build the question; "<image>" marks where the image is injected.
    if request.text_input:
        question = f"<image>\n{request.text_input}"
    else:
        question = "<image>\nDescribe the image."
    generation_config = {
        "max_new_tokens": request.max_new_tokens,
        "do_sample": request.do_sample,
    }
    # 5. Generate. model.chat returns the decoded response string directly.
    with torch.no_grad():
        response = model.chat(tokenizer, pixel_values, question, generation_config)
    return {"generated_text": response}
@app.get("/")
async def read_root():
    """Landing endpoint; points clients at the interactive API docs."""
    welcome = {"message": "InternVL2_5-2B API. Go to /docs for API documentation."}
    return welcome