# omniparser / app.py
# (Hugging Face Space page chrome captured with the source: "Sanket17's picture",
#  "Update app.py", commit "26966db verified" — commented out so the file parses.)
from fastapi import FastAPI, UploadFile, Form
from fastapi.responses import JSONResponse
from transformers import AutoProcessor, AutoModelForVisualQuestionAnswering
from PIL import Image
import torch
import uvicorn
import os
# Initialize FastAPI app (served by uvicorn in the __main__ guard below).
app = FastAPI()
# Access the Hugging Face token from the Space's secret section.
# NOTE(review): the secret is named "HP_token" — this looks like a typo for the
# conventional "HF_TOKEN"; confirm it matches the secret configured in the Space,
# otherwise os.getenv returns None and gated-repo downloads will fail.
hf_token = os.getenv("HP_token")
# Load model and processor with the token. These downloads run at import time,
# so the app blocks until the weights are fetched/cached.
# NOTE(review): `use_auth_token` is deprecated in recent transformers releases
# in favor of `token=` — verify against the pinned transformers version.
processor = AutoProcessor.from_pretrained("Sanket17/hello", use_auth_token=hf_token)
model = AutoModelForVisualQuestionAnswering.from_pretrained("Sanket17/hello", use_auth_token=hf_token)
@app.post("/vqa/")
async def visual_question_answer(file: UploadFile, question: str = Form(...)):
    """
    Visual question answering endpoint.

    - file: uploaded image file (any format PIL can open; converted to RGB)
    - question: textual question about the image

    Returns JSON ``{"question": ..., "answer": ...}`` on success, or
    ``{"error": ...}`` with HTTP 500 on any failure.
    """
    try:
        # Load the upload and normalize to 3-channel RGB for the processor.
        image = Image.open(file.file).convert("RGB")

        # Preprocess image + question into model tensors.
        inputs = processor(images=image, text=question, return_tensors="pt")

        # Inference only — disable autograd to avoid building a graph.
        with torch.no_grad():
            outputs = model(**inputs)

        # VQA models of this kind emit classification logits over a closed
        # answer vocabulary. The argmax is a *label id*, so it must be mapped
        # through model.config.id2label — decoding it with the tokenizer
        # (processor.decode) would return an unrelated subword token.
        answer_idx = outputs.logits.argmax(dim=-1).item()
        id2label = getattr(model.config, "id2label", None) or {}
        answer_str = id2label.get(answer_idx, str(answer_idx))

        return JSONResponse(content={"question": question, "answer": answer_str})
    except Exception as e:
        # Boundary handler: surface any failure as a 500 with the message.
        return JSONResponse(content={"error": str(e)}, status_code=500)
# Script entry point: serve the FastAPI app with uvicorn when executed
# directly (listens on every interface, port 8000).
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )