| | import cv2 |
| | import io |
| | import numpy as np |
| | from PIL import Image |
| |
|
| | import pytesseract |
| |
|
| | from fastapi import FastAPI, UploadFile, File |
| | from fastapi.middleware.cors import CORSMiddleware |
| |
|
| | from mltu.inferenceModel import OnnxInferenceModel |
| | from mltu.utils.text_utils import ctc_decoder |
| | from mltu.transformers import ImageResizer |
| | from mltu.configs import BaseModelConfigs |
| |
|
| | from textblob import TextBlob |
| | from happytransformer import HappyTextToText, TTSettings |
| |
|
| |
|
| | from transformers import AutoTokenizer, T5ForConditionalGeneration |
| | from pydantic import BaseModel |
| |
|
# Text-editing model: the tokenizer and the seq2seq weights are both loaded
# once, at import time, from the "grammarly/coedit-large" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("grammarly/coedit-large")
chatModel = T5ForConditionalGeneration.from_pretrained("grammarly/coedit-large")

# Recognizer configuration, read from a path relative to the working directory.
configs = BaseModelConfigs.load("./configs.yaml")

# Generation settings for happytransformer (5 beams, length bounds 1 and 100).
# NOTE(review): not referenced anywhere in the visible part of this file —
# confirm it is still needed.
beam_settings = TTSettings(num_beams=5, min_length=1, max_length=100)

app = FastAPI()

# NOTE(review): wildcard origins are combined with allow_credentials=True
# below. The FastAPI CORS documentation advises listing origins explicitly
# when credentials are allowed — confirm this is intended outside local
# development.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| |
|
| |
|
class ImageToWordModel(OnnxInferenceModel):
    """Recognizer that runs an ONNX inference model and CTC-decodes its output."""

    def __init__(self, char_list, *args, **kwargs):
        """Initialise the parent inference model and keep the vocabulary.

        Args:
            char_list: Vocabulary handed to ``ctc_decoder`` when decoding.
            *args: Forwarded unchanged to ``OnnxInferenceModel``.
            **kwargs: Forwarded unchanged to ``OnnxInferenceModel``.
        """
        super().__init__(*args, **kwargs)
        self.char_list = char_list

    def predict(self, image: np.ndarray):
        """Return the text decoded from a single image.

        Args:
            image: Image array to recognise.

        Returns:
            The first entry produced by ``ctc_decoder`` for this image.
        """
        # The first two entries of the model input shape, reversed, are the
        # positional size arguments the resizer expects (presumably
        # height/width -> width/height — confirm against mltu).
        target_size = self.input_shape[:2][::-1]
        resized = ImageResizer.resize_maintaining_aspect_ratio(image, *target_size)

        # The model is fed a float32 batch containing exactly one image.
        batch = np.expand_dims(resized, axis=0).astype(np.float32)

        outputs = self.model.run(None, {self.input_name: batch})
        decoded = ctc_decoder(outputs[0], self.char_list)
        return decoded[0]
| |
|
| |
|
# Single shared recognizer, built from the model path and vocabulary in `configs`.
model = ImageToWordModel(model_path=configs.model_path, char_list=configs.vocab)
# Most recent OCR result: written by the two extraction endpoints and read by
# /chat_prompt/.
# NOTE(review): this is process-wide state shared by every client — confirm
# that cross-request sharing is intended.
extracted_text = ""
| |
|
@app.post("/extract_handwritten_text/")
async def predict_text(image: UploadFile):
    """Recognise handwritten text in an uploaded image.

    The recognised text is also stored in the module-level ``extracted_text``
    so that ``/chat_prompt/`` can use it afterwards.

    Args:
        image: Uploaded image file.

    Returns:
        ``{"text": <recognised text>}`` on success, or the same
        ``{"error": ...}`` payload as ``/extract_text/`` when the upload is
        empty or cannot be decoded as an image.
    """
    global extracted_text

    invalid_upload = {"error": "Invalid file format. Please upload an image."}

    raw_bytes = await image.read()
    # cv2.imdecode raises on an empty buffer, so reject empty uploads up front.
    if not raw_bytes:
        return invalid_upload

    pixel_buffer = np.frombuffer(raw_bytes, np.uint8)
    decoded = cv2.imdecode(pixel_buffer, cv2.IMREAD_COLOR)
    # imdecode signals undecodable data by returning None instead of raising;
    # passing that on would crash inside the model with an opaque error.
    if decoded is None:
        return invalid_upload

    # Only overwrite the shared result once recognition has succeeded.
    extracted_text = model.predict(decoded)
    return {"text": extracted_text}
| |
|
| |
|
@app.post("/extract_text/")
async def extract_text_from_image(image: UploadFile):
    """Run Tesseract OCR on an uploaded image.

    The recognised text is also stored in the module-level ``extracted_text``
    so that ``/chat_prompt/`` can use it afterwards.

    Args:
        image: Uploaded file; its content type must start with ``image/``.

    Returns:
        ``{"text": <recognised text>}`` on success, or
        ``{"error": "Invalid file format. Please upload an image."}`` when the
        upload is not declared as an image or cannot be opened as one.
    """
    global extracted_text

    invalid_upload = {"error": "Invalid file format. Please upload an image."}

    # content_type is None when the client sends no Content-Type for the part.
    declared_type = image.content_type or ""
    if not declared_type.startswith("image/"):
        return invalid_upload

    image_bytes = await image.read()
    try:
        # Only the open call is guarded: Pillow raises UnidentifiedImageError
        # (an OSError subclass) for data it cannot recognise, while OS-level
        # errors from Tesseract itself must still propagate.
        img = Image.open(io.BytesIO(image_bytes))
    except OSError:
        return invalid_upload

    extracted_text = pytesseract.image_to_string(img)
    return {"text": extracted_text}
| |
|
# Request body accepted by the /chat_prompt/ endpoint.
class ChatPrompt(BaseModel):
    # Instruction text; the handler joins it with ": " in front of the most
    # recently extracted text before sending it to the editing model.
    prompt: str
| |
|
@app.post("/chat_prompt/")
async def chat_prompt(request: ChatPrompt):
    # Applies the caller's instruction to the most recently extracted text and
    # returns the editing model's output under the "edited_text" key.
    # `extracted_text` is only read here, so no `global` declaration is needed.
    # NOTE(review): generation runs synchronously inside this async handler —
    # confirm that blocking the event loop for its duration is acceptable.
    instruction = ": ".join((request.prompt, extracted_text))
    print(instruction)

    encoded = tokenizer(instruction, return_tensors="pt")
    generated = chatModel.generate(encoded.input_ids, max_length=256)
    result = tokenizer.decode(generated[0], skip_special_tokens=True)

    return {"edited_text": result}
| |
|