|
|
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, status, Query |
|
|
from fastapi.responses import FileResponse |
|
|
from pydantic import BaseModel, EmailStr, Field |
|
|
from typing import Optional |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
import pickle |
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib.font_manager as fm |
|
|
import os |
|
|
import io |
|
|
import sys |
|
|
import tempfile |
|
|
import requests |
|
|
from PIL import Image |
|
|
import uvicorn |
|
|
import shutil |
|
|
from pathlib import Path |
|
|
import py_text_scan |
|
|
from sqlalchemy import create_engine, Column, Integer, String, Boolean, Text, DateTime |
|
|
|
|
|
from sqlalchemy.ext.declarative import declarative_base |
|
|
from sqlalchemy.orm import sessionmaker, Session |
|
|
from passlib.context import CryptContext |
|
|
import datetime |
|
|
|
|
|
|
|
|
# SQLite database configuration.
# check_same_thread=False is required because FastAPI handles requests on a
# thread pool, so the single SQLite connection is shared across threads.
DATABASE_URL = "sqlite:///./test.db"


engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})


# Session factory: one short-lived Session per unit of work.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


# Declarative base class for the ORM models defined below.
Base = declarative_base()
|
|
|
|
|
|
|
|
class UserModel(Base):
    """SQLAlchemy ORM model for an application user account."""

    __tablename__ = "users"

    # Surrogate primary key.
    id = Column(Integer, primary_key=True, index=True)
    # Login name; unique per account.
    username = Column(String, unique=True, index=True)
    # Contact address; also unique per account.
    email = Column(String, unique=True, index=True)
    # Password hash only — presumably produced by the CryptContext imported
    # above; the hashing call site is not visible in this chunk.
    hashed_password = Column(String)
    # Soft-disable flag; new accounts start active.
    is_active = Column(Boolean, default=True)
    # Admin privilege flag; new accounts are non-admin.
    is_admin = Column(Boolean, default=False)
|
|
|
|
|
class FeedbackModel(Base):
    """SQLAlchemy ORM model for a user-submitted feedback comment."""

    __tablename__ = "feedback"

    # Surrogate primary key.
    id = Column(Integer, primary_key=True, index=True)
    # Stored as plain text, not a foreign key into users.
    username = Column(String)
    # Free-form feedback body.
    comment = Column(Text)
    # Naive UTC timestamp assigned at insert time.
    # NOTE(review): datetime.datetime.utcnow is deprecated since Python 3.12;
    # switching to a timezone-aware default would change stored values — confirm
    # before migrating.
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
|
|
|
|
|
# Create the users/feedback tables at import time if they do not already exist.
Base.metadata.create_all(bind=engine)
|
|
|
|
|
|
|
|
class OCRResponse(BaseModel):
    """Response schema for the /process/ endpoint."""

    # Text captured from py_text_scan's stdout, or a status message when
    # the engine is disabled.
    sakshi_output: str
    # Number of word-sized bounding boxes found by detect_words().
    word_count: int
    # Decoded label from the Keras model, or a status/error message.
    prediction_label: str
|
|
|
|
|
# FastAPI application instance; this metadata feeds the auto-generated
# OpenAPI docs (/docs, /redoc).
app = FastAPI(
    title="Dynamic Hindi OCR API",
    description="API for Hindi OCR with selectable models from the frontend.",
    version="1.1.0"
)
|
|
|
|
|
|
|
|
# Hugging Face locations of the OCR artifacts, fetched on startup.
MODEL_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/hindi_ocr_model.keras"
ENCODER_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/label_encoder.pkl"
FONT_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/NotoSansDevanagari-Regular.ttf"

# Local cache paths (relative to the working directory) for the artifacts.
MODEL_PATH = "hindi_ocr_model.keras"
ENCODER_PATH = "label_encoder.pkl"
FONT_PATH = "NotoSansDevanagari-Regular.ttf"

# Populated by the startup event; remain None if loading fails.
model = None
label_encoder = None

# Maps artifact name -> temp-file path produced by the most recent request.
# NOTE(review): module-global state, so concurrent requests overwrite each
# other's entries — confirm single-user assumption.
session_files = {}
|
|
|
|
|
def download_file(url, dest, timeout=60):
    """Download *url* to *dest* unless the file already exists.

    The payload is streamed to a temporary file in the destination's
    directory and atomically renamed into place, so an interrupted or
    failed download can never leave a truncated file at *dest* (which
    the exists-check would otherwise treat as complete on the next run).

    Args:
        url: Source URL.
        dest: Local destination path.
        timeout: Per-request timeout in seconds (new parameter, defaulted,
            so existing callers are unaffected; the original call could
            hang forever without it).

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.RequestException: On connection/timeout failures.
    """
    if os.path.exists(dest):
        return
    print(f"Downloading {dest}...")
    # Temp file lives next to dest so os.replace stays on one filesystem
    # (rename is atomic on POSIX in that case).
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(os.path.abspath(dest)))
    try:
        # Context managers close both the file and the HTTP response.
        with os.fdopen(fd, 'wb') as f, requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        os.replace(tmp_path, dest)
    except Exception:
        # Best-effort removal of the partial temp file, then re-raise.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
    print(f"Downloaded {dest}")
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Fetch the OCR artifacts, register the Devanagari font with
    matplotlib, and load the Keras model and label encoder into the
    module-level globals."""
    global model, label_encoder

    # Same download order as before: model, encoder, then font.
    for url, path in ((MODEL_URL, MODEL_PATH),
                      (ENCODER_URL, ENCODER_PATH),
                      (FONT_URL, FONT_PATH)):
        download_file(url, path)

    # Make the Devanagari font available to any matplotlib rendering.
    if os.path.exists(FONT_PATH):
        fm.fontManager.addfont(FONT_PATH)
        plt.rcParams['font.family'] = 'Noto Sans Devanagari'

    if os.path.exists(MODEL_PATH):
        model = tf.keras.models.load_model(MODEL_PATH)
    else:
        model = None

    if os.path.exists(ENCODER_PATH):
        with open(ENCODER_PATH, 'rb') as f:
            label_encoder = pickle.load(f)
|
|
|
|
|
|
|
|
def detect_words(image, min_width=10, min_height=10):
    """Locate word-sized regions in a grayscale image.

    Otsu-thresholds the inverted image, dilates to merge nearby glyphs
    into word blobs, then draws a green bounding box around every
    external contour larger than the minimum size.

    Args:
        image: 2-D grayscale image (ndarray).
        min_width: Minimum bounding-box width in pixels for a contour to
            count as a word (generalizes the previously hard-coded 10).
        min_height: Minimum bounding-box height in pixels (previously 10).

    Returns:
        Tuple of (BGR image annotated with green word boxes, word count).
    """
    _, binary = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    kernel = np.ones((3, 3), np.uint8)
    # Two dilation passes bridge gaps between characters of one word.
    dilated = cv2.dilate(binary, kernel, iterations=2)
    contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    word_img = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    word_count = 0
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        # Ignore specks smaller than the minimum word size.
        if w > min_width and h > min_height:
            cv2.rectangle(word_img, (x, y), (x + w, y + h), (0, 255, 0), 2)
            word_count += 1
    return word_img, word_count
|
|
|
|
|
def run_py_text_scan(image_path):
    """Run py_text_scan OCR on *image_path* and return its printed output.

    py_text_scan.generate() writes its recognition result to stdout
    rather than returning it, so stdout is captured for the duration of
    the call.

    Args:
        image_path: Path of the image file to scan.

    Returns:
        Everything py_text_scan.generate() printed, as a single string.
    """
    # Function-scope stdlib import keeps the module import block untouched.
    from contextlib import redirect_stdout

    buffer = io.StringIO()
    # redirect_stdout restores sys.stdout even if generate() raises,
    # replacing the manual save/swap/restore in try/finally.
    with redirect_stdout(buffer):
        py_text_scan.generate(image_path)
    return buffer.getvalue()
|
|
|
|
|
def _keras_prediction(img):
    """Classify grayscale *img* with the global Keras model.

    Returns the decoded label, or a human-readable status string when the
    model/encoder is unavailable or prediction fails.
    """
    if model is None or label_encoder is None:
        return "Keras model not loaded on server"
    try:
        # Model input: a (1, 32, 128, 1) batch normalized to [0, 1].
        img_resized = cv2.resize(img, (128, 32))
        img_input = (img_resized / 255.0)[np.newaxis, ..., np.newaxis]
        pred = model.predict(img_input)
        return label_encoder.inverse_transform([np.argmax(pred)])[0]
    except Exception as e:
        return f"Keras Error: {str(e)}"


def _scan_with_py_text_scan(img):
    """Write *img* to a temp PNG, OCR it with py_text_scan, and clean up."""
    # mkstemp + close avoids keeping an open handle on the path we hand
    # to cv2.imwrite (the old NamedTemporaryFile(...).name pattern leaked
    # the handle and fails on Windows).
    fd, tmp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        cv2.imwrite(tmp_path, img)
        return run_py_text_scan(tmp_path)
    finally:
        # Remove the temp file even if the OCR call raises.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass


def process_image(image_array, use_keras: bool, use_py_text_scan: bool):
    """Run the full OCR pipeline on an RGB image array.

    Steps: convert to grayscale, detect and annotate word boxes (saving
    the annotated image path into session_files — presumably served by a
    download endpoint elsewhere in the app; confirm), then optionally run
    the Keras classifier and/or the py_text_scan engine.

    Args:
        image_array: RGB image as an ndarray (e.g. np.array(PIL image)).
        use_keras: Whether to run the Keras classifier.
        use_py_text_scan: Whether to run the py_text_scan engine.

    Returns:
        Dict with keys "sakshi_output", "word_count", "prediction_label".
    """
    img = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
    word_detected_img, word_count = detect_words(img)

    # Persist the annotated image; mkstemp does not leak an open handle.
    fd, word_detection_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    cv2.imwrite(word_detection_path, word_detected_img)
    session_files['word_detection'] = word_detection_path

    pred_label = _keras_prediction(img) if use_keras else "Keras model disabled by user"
    sakshi_output = _scan_with_py_text_scan(img) if use_py_text_scan else "py_text_scan disabled by user"

    return {
        "sakshi_output": sakshi_output,
        "word_count": word_count,
        "prediction_label": pred_label
    }
|
|
|
|
|
|
|
|
@app.post("/process/", response_model=OCRResponse)
async def process(
    file: UploadFile = File(...),
    use_keras: bool = Query(True, description="Enable/disable the Keras model"),
    use_py_text_scan: bool = Query(True, description="Enable/disable the py_text_scan library")
):
    """Accept an uploaded image, run the OCR pipeline, and return the results.

    The query flags let the frontend toggle the Keras classifier and the
    py_text_scan engine independently.

    Raises:
        HTTPException 400: Upload is not an image.
        HTTPException 500: Any processing failure.
    """
    # content_type can be None if the client omits the header; the old
    # check would raise AttributeError instead of a clean 400.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    # Drop artifacts from the previous request before producing new ones.
    # EAFP unlink replaces the exists-check + bare `except: pass`.
    for filepath in session_files.values():
        try:
            os.unlink(filepath)
        except OSError:
            # Already gone or not removable — best-effort cleanup only.
            pass
    session_files.clear()

    temp_file_path = ""
    try:
        # Spool the upload to disk so PIL can open it by path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            shutil.copyfileobj(file.file, temp_file)
            temp_file_path = temp_file.name

        # Context manager closes PIL's file handle; otherwise the unlink
        # in `finally` can fail on Windows while the image is still open.
        with Image.open(temp_file_path) as image:
            image_array = np.array(image)

        result = process_image(image_array, use_keras, use_py_text_scan)

        return OCRResponse(
            sakshi_output=result["sakshi_output"],
            word_count=result["word_count"],
            prediction_label=result["prediction_label"]
        )
    except HTTPException:
        # Let deliberate HTTP errors pass through unwrapped.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
    finally:
        # Always remove the spooled upload, whatever happened above.
        if temp_file_path and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
|
|
|
|
|
if __name__ == "__main__":
    # Development entry point: serve on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)