# MultiModalModel / app.py
# Last change: commit 039ebfd "Update app.py" (verified) by marveljo; file size 2.43 kB.
import logging
import threading
from io import BytesIO

import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
# Hugging Face Hub repository id of the checkpoint served by this app.
# Everything below runs once, at import time, before the API accepts requests.
model_id = "HPAI-BSC/Aloe-Vision-7B-AR"

# trust_remote_code=True allows transformers to execute Python code shipped in
# the model repository (presumably required by this checkpoint's custom
# classes — confirm against the repo before removing it).
processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
)

# Weights are loaded in bfloat16; device_map="auto" delegates placement of the
# weights across the available devices.
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

app = FastAPI(title="Aloe Vision 7B AR API")
logger = logging.getLogger(__name__)

# model.generate is synchronous and GPU-bound, so it runs in a worker thread to
# keep the event loop responsive. This lock keeps generations one at a time,
# as they were when generate() ran directly on the event loop.
_generate_lock = threading.Lock()

# Prompt used when the client uploads an image without a question.
_DEFAULT_IMAGE_PROMPT = "Describe this image briefly."


def _build_messages(image, question):
    """Build the single-turn chat message passed to the chat template.

    Args:
        image: Decoded RGB image, or None for a text-only request.
        question: User question; when empty/None (only valid together with an
            image) the default image-description prompt is used instead.

    Returns:
        A one-element list holding a "user" message whose content is the image
        part (if any) followed by exactly one text part.
    """
    content = []
    if image is not None:
        content.append({"type": "image", "image": image})
    content.append({"type": "text", "text": question or _DEFAULT_IMAGE_PROMPT})
    return [{"role": "user", "content": content}]


def _new_tokens_only(generated, input_ids):
    """Return only the token ids generated after the prompt.

    For decoder-only models ``generate`` returns the prompt ids followed by the
    new ids; that prefix is removed so the prompt is not echoed in the answer.
    Output that does not start with the prompt is returned unchanged.
    """
    prompt_len = input_ids.shape[-1]
    if generated.shape[-1] >= prompt_len and torch.equal(
        generated[:, :prompt_len], input_ids
    ):
        return generated[:, prompt_len:]
    return generated


def _generate_answer(messages):
    """Tokenize `messages`, run greedy generation and decode the reply (blocking)."""
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Hand the PIL images to the processor through its standard `images=`
    # argument, in the same order they appear in the messages.
    images = [
        part["image"]
        for message in messages
        for part in message["content"]
        if part["type"] == "image"
    ]
    inputs = processor(
        text=[text], images=images or None, return_tensors="pt"
    ).to(model.device)
    with _generate_lock:
        generated = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
            eos_token_id=processor.tokenizer.eos_token_id,
        )
    new_tokens = _new_tokens_only(generated, inputs["input_ids"])
    return processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()


@app.post("/predict")
async def predict(
    file: UploadFile = File(None),
    question: str = Form(None),
):
    """Answer a question about an image, a text-only question, or describe an image.

    Form fields:
        file: Optional image upload; decoded with PIL and converted to RGB.
        question: Optional question text. With an image but no question the
            model is asked to describe the image briefly.

    Returns:
        200 ``{"answer": str}`` on success; 400 ``{"error": str}`` when neither
        input is provided or the upload cannot be decoded as an image; 500
        ``{"error": str}`` for any other failure.
    """
    if question is not None:
        # Whitespace-only questions are treated as "no question".
        question = question.strip()
    if file is None and not question:
        return JSONResponse(
            {"error": "You must provide an image, text, or both."}, status_code=400
        )
    try:
        image = None
        if file is not None:
            data = await file.read()
            try:
                # Image.open is lazy; convert() forces the full decode so that
                # truncated/corrupt uploads fail here with OSError.
                image = Image.open(BytesIO(data)).convert("RGB")
            except (OSError, Image.DecompressionBombError) as e:
                return JSONResponse({"error": f"Invalid image: {e}"}, status_code=400)
        messages = _build_messages(image, question)
        answer = await run_in_threadpool(_generate_answer, messages)
        return JSONResponse({"answer": answer})
    except Exception as e:
        # Top-level boundary: keep the traceback in the server log and return
        # the same 500 payload shape clients already receive.
        logger.exception("Prediction failed")
        return JSONResponse({"error": str(e)}, status_code=500)