Hugging Face Spaces status: Runtime error
| import os | |
| import torch | |
| from fastapi import FastAPI, UploadFile, File, Depends | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, AutoModel | |
| from PIL import Image | |
| from typing import Optional | |
| import io | |
# Application object — all HTTP routes are registered against this instance.
app = FastAPI()

# Checkpoint to serve and the device it runs on.
model_id = "OpenGVLab/InternVL2_5-2B"
device = "cpu"

# Model and tokenizer start out unset; they are populated lazily by
# load_model_if_needed() on the first request so server startup stays fast.
model = None
tokenizer = None
def load_model_if_needed():
    """Lazily populate the module-level ``model`` and ``tokenizer`` globals.

    The first call downloads/loads the checkpoint; every later call is a
    no-op because ``model`` is already set.
    """
    global model, tokenizer

    # Guard clause: nothing to do once the model is in memory.
    if model is not None:
        return

    print("Loading model... this may take a moment.")

    # trust_remote_code is required: InternVL checkpoints ship custom code.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # bfloat16 roughly halves the weight footprint (5GB vs 10GB) and
    # low_cpu_mem_usage streams weights in to keep peak RAM down.
    loaded = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    model = loaded.to(device).eval()
    print("Model loaded successfully!")
class GenerateRequest(BaseModel):
    """Request parameters for the generation endpoint (used via Depends())."""

    # Optional user prompt; when absent the endpoint falls back to a
    # generic "Describe the image." prompt.
    text_input: Optional[str] = None
    # Upper bound on generated tokens passed to model.generate().
    max_new_tokens: int = 1024
    # False = greedy decoding (deterministic); True enables sampling.
    do_sample: bool = False
@app.post("/generate")
async def generate(image: UploadFile = File(...), request: GenerateRequest = Depends()):
    """Generate text from an uploaded image (and optional prompt).

    Fix: the handler was defined but never registered with the FastAPI app,
    so the endpoint did not exist; ``@app.post("/generate")`` registers it.

    Returns a JSON object ``{"generated_text": ...}``.
    """
    # Load the model lazily on the first request (no-op afterwards).
    load_model_if_needed()

    # Decode the upload into an RGB PIL image.
    image_bytes = await image.read()
    pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    # Build the prompt; fall back to a generic caption request.
    if request.text_input:
        prompt = f"<s><image>\n{request.text_input}</s>"
    else:
        # Plain string — the original f-string had no placeholders.
        prompt = "<s><image>\nDescribe the image.</s>"

    # NOTE(review): AutoTokenizer.__call__ takes text (and optional text
    # pair), not a PIL image — InternVL's documented inference entry point
    # is model.chat(tokenizer, pixel_values, question, generation_config).
    # Preserved as-is; confirm against the model card before relying on it.
    inputs = tokenizer(prompt, pil_image, return_tensors="pt").to(device)

    generation_args = {
        "max_new_tokens": request.max_new_tokens,
        "do_sample": request.do_sample,
    }

    # Inference only — disable autograd to save memory.
    with torch.no_grad():
        output_ids = model.generate(**inputs, **generation_args)

    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return {"generated_text": response}
@app.get("/")
async def read_root():
    """Landing endpoint pointing users at the interactive /docs page.

    Fix: the handler was defined but never registered with the FastAPI app,
    so GET / returned 404; ``@app.get("/")`` registers the route.
    """
    return {"message": "InternVL2_5-2B API. Go to /docs for API documentation."}