Sanket17 commited on
Commit
107c705
·
1 Parent(s): aac46a1

updaded files

Browse files
Files changed (3) hide show
  1. src/api.py +17 -0
  2. src/model.py +6 -0
  3. src/schemas.py +5 -0
src/api.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from .model import load_model
3
+ from .schemas import Query
4
+
5
+ app = FastAPI()
6
+ processor, model = load_model()
7
+
8
+ @app.post("/predict")
9
+ async def predict(query: Query):
10
+ inputs = processor(images=query.image_url, text=query.question, return_tensors="pt")
11
+ outputs = model(**inputs)
12
+ predicted_answer = processor.decode(outputs.logits.argmax(-1)[0], skip_special_tokens=True)
13
+ return {"answer": predicted_answer}
14
+
15
+ @app.get("/")
16
+ async def root():
17
+ return {"message": "OmniParser API is running. Use /predict endpoint for predictions."}
src/model.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from transformers import AutoProcessor, AutoModelForVisualQuestionAnswering
2
+
3
+ def load_model():
4
+ processor = AutoProcessor.from_pretrained("microsoft/OmniParser")
5
+ model = AutoModelForVisualQuestionAnswering.from_pretrained("microsoft/OmniParser")
6
+ return processor, model
src/schemas.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+
3
+ class Query(BaseModel):
4
+ image_url: str
5
+ question: str