MultiModalModel / app.py
marveljo's picture
Update app.py
8f0412a verified
raw
history blame
1.92 kB
from fastapi import FastAPI, UploadFile, Form
from fastapi.responses import JSONResponse
from transformers import AutoProcessor, AutoModelForVision2Seq
import torch
from PIL import Image
import io
app = FastAPI(title="Aloe Vision Backend")

# The processor and model are loaded once, at import time, and the request
# handlers below reuse these module-level instances.
print("🚀 Loading model, please wait...")
model_id = "HPAI-BSC/Aloe-Vision-7B-AR"
# trust_remote_code=True allows Python code shipped in the Hub repository to
# run while loading; keep it only for repositories you trust.
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    # Let the loader decide where to place the weights on available devices.
    device_map="auto",
    trust_remote_code=True,
)
print("✅ Model loaded!")
@app.post("/analyze")
async def analyze(file: UploadFile, prompt: str = Form("Describe the image")):
    """Run the vision-language model on an uploaded image and a text prompt.

    Args:
        file: Uploaded image. Its bytes are opened with PIL and converted to
            RGB before being handed to the processor.
        prompt: Instruction sent to the model together with the image.
            Defaults to "Describe the image".

    Returns:
        ``{"result": <generated text>}`` on success. If any step raises, a
        JSON response with status code 500 and body ``{"error": <message>}``.
    """
    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Single-turn chat: one user message carrying the image and the text.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # The decoded image is already in hand, so pass it directly through
        # the processor's `images` argument rather than re-extracting it from
        # `messages` via `processor.process_vision_info(...)`.
        # NOTE(review): `process_vision_info` is normally a standalone helper
        # (qwen_vl_utils) that returns a tuple, not a processor method
        # returning a mapping - confirm against the model's remote code.
        inputs = processor(
            text=[text],
            images=[image],
            return_tensors="pt",
        ).to(model.device)

        generated = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,  # greedy decoding
            eos_token_id=processor.tokenizer.eos_token_id,
        )

        # For decoder-only models `generate` returns the prompt tokens
        # followed by the newly generated ones; strip the prompt so the
        # response contains only the model's answer instead of echoing the
        # chat template and the user's prompt.
        prompt_length = inputs["input_ids"].shape[1]
        answer_tokens = generated[:, prompt_length:]
        output_text = processor.batch_decode(
            answer_tokens, skip_special_tokens=True
        )[0]
        return {"result": output_text}
    except Exception as e:
        # Endpoint boundary: report any failure as a JSON 500 rather than
        # letting the exception propagate.
        return JSONResponse(content={"error": str(e)}, status_code=500)
@app.get("/")
def root():
    """Return a static payload signalling that the backend is running."""
    status_payload = {
        "status": "ok",
        "message": "Aloe Vision Backend running!",
    }
    return status_payload