|
|
from fastapi import FastAPI, File, UploadFile
|
|
|
from fastapi.responses import JSONResponse, RedirectResponse
|
|
|
from transformers import ViltProcessor, ViltForQuestionAnswering
|
|
|
from PIL import Image
|
|
|
import requests
|
|
|
import io
|
|
|
|
|
|
# FastAPI application exposing the visual question answering endpoints.
app = FastAPI(
    version="0.0.1",
    title="Visual Question and Answering API",
)

# Hugging Face checkpoint used for both preprocessing and the model, so the
# processor and weights always come from the same release. Both objects are
# created once at import time and reused by every request.
_VQA_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"
processor = ViltProcessor.from_pretrained(_VQA_CHECKPOINT)
model = ViltForQuestionAnswering.from_pretrained(_VQA_CHECKPOINT)
|
|
|
|
|
|
def get_answer(image, text):
    """Answer a natural-language question about an image with the ViLT VQA model.

    Args:
        image: Raw encoded image bytes (any format PIL can open).
        text: The question to ask about the image.

    Returns:
        The label in ``model.config.id2label`` whose logit is highest.

    Raises:
        ValueError: If ``text`` is missing or empty.
        PIL.UnidentifiedImageError: If ``image`` cannot be decoded as an image.
        Exception: Any other processor/model error propagates to the caller
            instead of being returned as if it were an answer.
    """
    # torch is already required by the ``return_tensors="pt"`` call below;
    # imported locally to keep this function self-contained.
    import torch

    if not text:
        raise ValueError("A non-empty question text is required.")

    # Close the decoded source image once the RGB copy has been made.
    with Image.open(io.BytesIO(image)) as src:
        img = src.convert("RGB")

    encoding = processor(img, text, return_tensors="pt")

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**encoding)

    idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[idx]
|
|
|
|
|
|
@app.get("/", include_in_schema=False)
async def index():
    """Send requests for the bare root URL to the interactive API docs page."""
    docs_url = "/docs"
    return RedirectResponse(docs_url)
|
|
|
|
|
|
@app.post("/answer")
async def process_image(image: UploadFile = File(...), text: str = None):
    """Answer the question ``text`` about the uploaded ``image``.

    ``image`` is a multipart file upload; ``text`` is a plain query parameter.

    Returns:
        200 with ``{"Answer": <answer>}`` on success;
        400 with ``{"detail": ...}`` if ``text`` is missing/empty or the
        upload is not a decodable image;
        500 with ``{"Sorry, please reach out to the Admin!": <error>}`` for
        any other failure during reading or inference.
    """
    # Local imports keep this handler self-contained; both packages are
    # already dependencies of this file.
    import logging

    from fastapi.concurrency import run_in_threadpool
    from PIL import UnidentifiedImageError

    if not text:
        return JSONResponse(
            {"detail": "Query parameter 'text' (the question) is required."},
            status_code=400,
        )

    try:
        data = await image.read()
        # Model inference is synchronous and compute-bound; running it in the
        # threadpool keeps the event loop free to serve other requests.
        answer = await run_in_threadpool(get_answer, data, text)
    except UnidentifiedImageError as e:
        # The client uploaded something that is not an image: a 4xx, not a 5xx.
        return JSONResponse({"detail": str(e)}, status_code=400)
    except Exception as e:
        # Top-level request boundary: record the traceback server-side, then
        # report the failure with a proper error status instead of 200.
        logging.getLogger(__name__).exception("VQA request failed")
        return JSONResponse(
            {"Sorry, please reach out to the Admin!": str(e)},
            status_code=500,
        )

    return JSONResponse({"Answer": answer})
|
|
|
|