|
|
from fastapi import FastAPI, UploadFile, Form |
|
|
from fastapi.responses import JSONResponse |
|
|
from transformers import AutoProcessor, AutoModelForVisualQuestionAnswering |
|
|
from PIL import Image |
|
|
import torch |
|
|
import uvicorn |
|
|
|
|
|
|
|
|
# ASGI application object; route handlers below register themselves on it.
app = FastAPI()

# Single source of truth for the checkpoint so the processor and the model
# can never be loaded from two different repositories by accident.
MODEL_ID = "Sanket17/hello"

# SECURITY NOTE: trust_remote_code=True lets the Hugging Face repository
# execute its own Python code at load time. Only keep this enabled for
# repositories you control or have audited.
# Both objects are loaded once at import time and shared by every request.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForVisualQuestionAnswering.from_pretrained(
    MODEL_ID, trust_remote_code=True
)
|
|
|
|
|
@app.post("/vqa/")
async def visual_question_answer(file: UploadFile, question: str = Form(...)):
    """Answer a free-text question about an uploaded image.

    Args:
        file: Uploaded image file; it is opened with Pillow and converted
            to RGB before being handed to the processor.
        question: Textual question about the image, sent as a form field.

    Returns:
        A ``JSONResponse`` whose body is ``{"question": ..., "answer": ...}``
        on success. When the upload cannot be read as an image the body is
        ``{"error": ...}`` with status 400; any other failure yields
        ``{"error": ...}`` with status 500.
    """
    try:
        try:
            image = Image.open(file.file).convert("RGB")
        except OSError as exc:
            # Pillow signals unreadable or truncated image data with OSError
            # (UnidentifiedImageError is a subclass). That is a problem with
            # the request, so report it as a client error instead of a 500.
            return JSONResponse(content={"error": str(exc)}, status_code=400)

        inputs = processor(images=image, text=question, return_tensors="pt")

        # NOTE(review): the forward pass runs synchronously inside this
        # coroutine, so the event loop is blocked until it finishes.
        # Inference only: disable autograd so no gradient graph is built.
        with torch.no_grad():
            outputs = model(**inputs)

        # Index of the highest-scoring logit; `.item()` requires the result
        # to hold exactly one element (one image/question pair per request).
        answer = outputs.logits.argmax(dim=-1).item()

        # NOTE(review): the index is decoded as a token id. Classification
        # style VQA heads (e.g. ViLT) instead map the index through
        # `model.config.id2label` - confirm which applies to this checkpoint.
        answer_str = processor.decode([answer])

        return JSONResponse(content={"question": question, "answer": answer_str})
    except Exception as e:
        # Top-level boundary: surface unexpected failures as a JSON 500.
        return JSONResponse(content={"error": str(e)}, status_code=500)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: start uvicorn serving `app` on all interfaces,
    # port 8000.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
|
|
|