Add a new endpoint for the mobile app that accepts base64-encoded document uploads
Browse files- base64_test.ipynb +0 -0
- main.py +13 -13
base64_test.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
main.py
CHANGED
|
@@ -5,11 +5,12 @@ from PIL import Image
|
|
| 5 |
from transformers import LiltForTokenClassification, AutoTokenizer
|
| 6 |
import token_classification
|
| 7 |
import torch
|
| 8 |
-
from fastapi import FastAPI, UploadFile
|
| 9 |
from contextlib import asynccontextmanager
|
| 10 |
import json
|
| 11 |
import io
|
| 12 |
from models import LiLTRobertaLikeForRelationExtraction
|
|
|
|
| 13 |
config = {}
|
| 14 |
|
| 15 |
@asynccontextmanager
|
|
@@ -30,12 +31,20 @@ app = FastAPI(lifespan=lifespan)
|
|
| 30 |
|
| 31 |
@app.post("/submit-doc")
|
| 32 |
async def ProcessDocument(file: UploadFile):
|
| 33 |
-
|
|
|
|
| 34 |
reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
|
| 35 |
return reOutput
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
image = Image.open(io.BytesIO(content))
|
| 40 |
ocr_df = config['vision_client'].ocr(content, image)
|
| 41 |
input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping = config['processor'].process(ocr_df, image = image)
|
|
@@ -74,14 +83,5 @@ def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
|
|
| 74 |
question = config['tokenizer'].decode(input_ids[0][head_start:head_end])
|
| 75 |
answer = config['tokenizer'].decode(input_ids[0][tail_start:tail_end])
|
| 76 |
decoded_pred_relations.append((question, answer))
|
| 77 |
-
# print("Question:", question)
|
| 78 |
-
# print("Answer:", answer)
|
| 79 |
-
## This prints bboxes of each question and answer
|
| 80 |
-
# for item in merged_words:
|
| 81 |
-
# if item['text'] == question:
|
| 82 |
-
# print('Question', item['box'])
|
| 83 |
-
# if item['text'] == answer:
|
| 84 |
-
# print('Answer', item['box'])
|
| 85 |
-
# print("----------")
|
| 86 |
|
| 87 |
return {"pred_relations":json.dumps(outputs.pred_relations[0]), "entities":json.dumps(entities), "input_ids": json.dumps(input_ids.tolist()), "bboxes": json.dumps(bbox_org.tolist()),"token_labels":json.dumps(token_labels), "decoded_entities": json.dumps(decoded_entities), "decoded_pred_relations":json.dumps(decoded_pred_relations)}
|
|
|
|
| 5 |
from transformers import LiltForTokenClassification, AutoTokenizer
|
| 6 |
import token_classification
|
| 7 |
import torch
|
| 8 |
+
from fastapi import FastAPI, UploadFile, Form
|
| 9 |
from contextlib import asynccontextmanager
|
| 10 |
import json
|
| 11 |
import io
|
| 12 |
from models import LiLTRobertaLikeForRelationExtraction
|
| 13 |
+
from base64 import b64decode
|
| 14 |
config = {}
|
| 15 |
|
| 16 |
@asynccontextmanager
|
|
|
|
| 31 |
|
| 32 |
@app.post("/submit-doc")
async def ProcessDocument(file: UploadFile):
    """Accept a multipart document upload and run the extraction pipeline.

    Reads the raw bytes of the uploaded file, labels its tokens, then
    extracts question/answer relations and returns the serialized result.
    """
    raw_bytes = await file.read()
    labelled, ocr_frame, image_size = LabelTokens(raw_bytes)
    return ExtractRelations(labelled, ocr_frame, image_size)
|
| 38 |
|
| 39 |
+
@app.post("/submit-doc-mobile")
async def ProcessDocumentMobile(base64str: str = Form(...)):
    """Mobile-app variant of /submit-doc: accepts the document as a
    base64-encoded form field instead of a multipart file upload.

    NOTE(review): renamed from ``ProcessDocument`` — the duplicate name
    shadowed the /submit-doc handler at module level. FastAPI binds the
    route at decoration time, so the URL path is unchanged; confirm no
    other module imports the handler by name.
    """
    # b64decode accepts a str directly; the manual str.encode() round-trip
    # in the original was redundant.
    content = b64decode(base64str)
    tokenClassificationOutput, ocr_df, img_size = LabelTokens(content)
    reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
    return reOutput
|
| 46 |
+
|
| 47 |
+
def LabelTokens(content):
|
| 48 |
image = Image.open(io.BytesIO(content))
|
| 49 |
ocr_df = config['vision_client'].ocr(content, image)
|
| 50 |
input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping = config['processor'].process(ocr_df, image = image)
|
|
|
|
| 83 |
question = config['tokenizer'].decode(input_ids[0][head_start:head_end])
|
| 84 |
answer = config['tokenizer'].decode(input_ids[0][tail_start:tail_end])
|
| 85 |
decoded_pred_relations.append((question, answer))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
return {"pred_relations":json.dumps(outputs.pred_relations[0]), "entities":json.dumps(entities), "input_ids": json.dumps(input_ids.tolist()), "bboxes": json.dumps(bbox_org.tolist()),"token_labels":json.dumps(token_labels), "decoded_entities": json.dumps(decoded_entities), "decoded_pred_relations":json.dumps(decoded_pred_relations)}
|