ImageTest / api.py
NishantD's picture
Upload 5 files
aca6900 verified
import io
import logging

import requests
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse, RedirectResponse
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering
# Hugging Face checkpoint identifier shared by the processor and the model,
# kept in one place so the two can never drift apart.
_VQA_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"

# ASGI application object; the route decorators below register against it.
app = FastAPI(
    title="Visual Question and Answering API",
    version="0.0.1",
)

# Both objects are created once, at import time, and reused by every request.
# The processor prepares the (image, question) pair; the model scores answers.
processor = ViltProcessor.from_pretrained(_VQA_CHECKPOINT)
model = ViltForQuestionAnswering.from_pretrained(_VQA_CHECKPOINT)
def get_answer(image, text):
    """Answer a natural-language question about an image.

    Args:
        image: Encoded image content as a bytes-like object; it is wrapped in
            ``io.BytesIO`` and opened with Pillow.
        text: The question to ask about the image.

    Returns:
        The entry of ``model.config.id2label`` whose logit is highest.

    Raises:
        ValueError: If ``text`` is ``None`` or a blank string.
        Exception: Any error raised while decoding the image or running the
            model propagates to the caller. Earlier versions returned the
            error text as if it were the answer, which made failures
            indistinguishable from real predictions.
    """
    if text is None or (isinstance(text, str) and not text.strip()):
        raise ValueError("A non-empty question is required.")

    # Close the decoder's handle as soon as the pixels have been copied into
    # the converted image; ``convert`` returns a new, independent image.
    with Image.open(io.BytesIO(image)) as raw_image:
        img = raw_image.convert("RGB")

    # Build the model inputs as PyTorch tensors.
    encoding = processor(img, text, return_tensors="pt")

    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**encoding)

    # Pick the highest-scoring class and map its index to the answer label.
    idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[idx]
@app.get("/", include_in_schema=False)
async def index():
    """Redirect requests for the root path to ``/docs``.

    The route is registered with ``include_in_schema=False``, so it is left
    out of the generated API schema.
    """
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
@app.post("/answer")
async def process_image(image: UploadFile = File(...), text: str = None):
    """Answer a question about an uploaded image.

    Args:
        image: The uploaded image file (required form field).
        text: The question to ask about the image. The parameter defaults to
            ``None`` for backward compatibility, but a request without a
            non-blank question is rejected.

    Returns:
        A ``JSONResponse``:
        * ``{"Answer": <label>}`` with status 200 on success;
        * ``{"detail": ...}`` with status 400 when ``text`` is missing/blank;
        * ``{"Sorry, please reach out to the Admin!": <error>}`` with status
          500 when reading the upload or running the model fails. This used
          to be sent with status 200, which hid failures from clients.
    """
    if text is None or not text.strip():
        return JSONResponse(
            {"detail": "Query parameter 'text' (the question) is required."},
            status_code=400,
        )

    try:
        payload = await image.read()
        # NOTE(review): get_answer is synchronous, so the event loop is
        # blocked while it runs. Offloading it to a worker thread would need
        # confirmation that the processor/model are safe to call concurrently.
        answer = get_answer(payload, text)
    except Exception as e:
        # Top-level boundary: record the traceback server-side and report the
        # failure with an error status instead of a 200.
        logging.getLogger(__name__).exception("Visual question answering failed")
        return JSONResponse(
            {"Sorry, please reach out to the Admin!": str(e)},
            status_code=500,
        )

    return JSONResponse({"Answer": answer})