File size: 2,137 Bytes
36fcbf8
 
 
 
 
 
 
 
 
 
 
 
1aecc68
8b44be5
1aecc68
 
 
2a7a917
1aecc68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36fcbf8
 
 
 
 
 
 
 
1aecc68
 
 
 
36fcbf8
 
 
 
 
 
 
 
1aecc68
36fcbf8
 
 
 
 
 
 
 
 
 
 
 
 
 
2a7a917
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import torch
from fastapi import FastAPI, UploadFile, File, Depends
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
from PIL import Image
from typing import Optional
import io

app = FastAPI()

# Hugging Face Hub id of the vision-language model being served.
model_id = "OpenGVLab/InternVL2_5-2B"
# Inference device; CPU-only deployment (no CUDA assumed).
device = "cpu"

# Global variables for the model, initially empty.
# They are populated lazily by load_model_if_needed() on the first request,
# so the server starts fast and only pays the load cost when needed.
model = None
tokenizer = None

def load_model_if_needed():
    """Lazily initialize the global ``model`` and ``tokenizer`` on first use.

    Subsequent calls return immediately once the model is already loaded.
    """
    global model, tokenizer
    if model is not None:
        return

    print("Loading model... this may take a moment.")

    # Tokenizer first. trust_remote_code is required because InternVL ships
    # its own tokenization/modeling code on the Hub.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # bfloat16 halves the memory footprint (roughly 5GB instead of 10GB),
    # and low_cpu_mem_usage avoids materializing a second full copy while
    # the weights are being loaded.
    loaded = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    model = loaded.to(device).eval()

    print("Model loaded successfully!")

class GenerateRequest(BaseModel):
    """Generation options for /generate, parsed from query params via Depends()."""

    # Optional user prompt; when omitted the handler falls back to a
    # generic "describe the image" instruction.
    text_input: Optional[str] = None
    # Upper bound on the number of new tokens generated in the reply.
    max_new_tokens: int = 1024
    # False => greedy decoding (deterministic); True => sampling.
    do_sample: bool = False

@app.post("/generate")
async def generate(image: UploadFile = File(...), request: GenerateRequest = Depends()):
    """Run InternVL on an uploaded image and return the generated text.

    Args:
        image: The uploaded image file (any format PIL can open).
        request: Generation options (prompt, max_new_tokens, do_sample).

    Returns:
        JSON body ``{"generated_text": <model reply>}``.

    Note:
        The previous implementation called ``tokenizer(prompt, pil_image, ...)``,
        but an AutoTokenizer only accepts text — the PIL image would be treated
        as a ``text_pair`` argument and raise. InternVL's remote code expects
        preprocessed ``pixel_values`` passed to ``model.chat(...)`` with an
        ``<image>`` placeholder in the question, which is what we do here.
    """
    # 1. Trigger the load on the first request
    load_model_if_needed()

    # Read image
    image_bytes = await image.read()
    pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    # 2. Preprocess to the single 448x448 tile InternVL's vision encoder
    # expects: HWC uint8 -> CHW float in [0, 1], ImageNet-normalized,
    # then cast to the model's bfloat16 dtype with a leading batch dim.
    pil_image = pil_image.resize((448, 448), Image.BICUBIC)
    pixel_values = (
        torch.frombuffer(bytearray(pil_image.tobytes()), dtype=torch.uint8)
        .reshape(448, 448, 3)
        .permute(2, 0, 1)
        .float()
        .div(255.0)
    )
    mean = torch.tensor((0.485, 0.456, 0.406)).view(3, 1, 1)
    std = torch.tensor((0.229, 0.224, 0.225)).view(3, 1, 1)
    pixel_values = ((pixel_values - mean) / std).to(torch.bfloat16).unsqueeze(0).to(device)

    # 3. Build the question; "<image>" is InternVL's image placeholder token.
    question = "<image>\n" + (request.text_input or "Describe the image.")

    generation_config = {
        "max_new_tokens": request.max_new_tokens,
        "do_sample": request.do_sample,
    }

    # 4. model.chat (provided by InternVL's trust_remote_code modeling file)
    # handles templating, generation, and decoding; it returns a plain string.
    with torch.no_grad():
        response = model.chat(tokenizer, pixel_values, question, generation_config)

    return {"generated_text": response}

@app.get("/")
async def read_root():
    """Landing endpoint that points users at the interactive API docs."""
    payload = {"message": "InternVL2_5-2B API. Go to /docs for API documentation."}
    return payload