|
|
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, status, Request |
|
|
from fastapi.responses import FileResponse, JSONResponse, HTMLResponse |
|
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm |
|
|
from pydantic import BaseModel, EmailStr, Field |
|
|
from typing import List, Optional |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
import pickle |
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib.font_manager as fm |
|
|
import os |
|
|
import io |
|
|
import sys |
|
|
import tempfile |
|
|
import requests |
|
|
from PIL import Image |
|
|
import uvicorn |
|
|
import shutil |
|
|
from pathlib import Path |
|
|
import py_text_scan |
|
|
from sqlalchemy import create_engine, Column, Integer, String, Boolean, Text, DateTime |
|
|
from sqlalchemy.ext.declarative import declarative_base |
|
|
from sqlalchemy.orm import sessionmaker, Session |
|
|
from passlib.context import CryptContext |
|
|
import datetime |
|
|
|
|
|
|
|
|
# Connection string for the SQLite file used by the application.
DATABASE_URL = "sqlite:///./test.db"

# ``check_same_thread=False`` is forwarded to the SQLite driver as a connect
# argument so the connection is not restricted to its creating thread.
engine = create_engine(
    DATABASE_URL,
    connect_args={"check_same_thread": False},
)

# Session factory bound to the engine; commits and flushes are explicit.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base class that every ORM model below inherits from.
Base = declarative_base()
|
|
|
|
|
|
|
|
class UserModel(Base):
    """ORM mapping for a row of the ``users`` table."""

    __tablename__ = "users"

    # Integer primary key (indexed).
    id = Column(Integer, primary_key=True, index=True)

    # Login name and e-mail address; both unique and indexed.
    username = Column(String, unique=True, index=True)
    email = Column(String, unique=True, index=True)

    # Password hash; create_user() fills it via pwd_context.hash().
    hashed_password = Column(String)

    # Account flags: active by default, not an administrator by default.
    is_active = Column(Boolean, default=True)
    is_admin = Column(Boolean, default=False)
|
|
|
|
|
class FeedbackModel(Base):
    """ORM mapping for a row of the ``feedback`` table."""

    __tablename__ = "feedback"

    # Integer primary key (indexed).
    id = Column(Integer, primary_key=True, index=True)

    # Name supplied with the feedback; stored as a plain string column.
    username = Column(String)

    # Free-form feedback text.
    comment = Column(Text)

    # Defaults to the value of datetime.datetime.utcnow() when a row is
    # inserted without an explicit timestamp.
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
|
|
|
|
|
# Create the tables declared on ``Base`` at import time.
Base.metadata.create_all(
    bind=engine,
)
|
|
|
|
|
|
|
|
class UserBase(BaseModel):
    """Fields shared by the user request and response schemas."""

    # Required; between 3 and 50 characters long.
    username: str = Field(..., min_length=3, max_length=50)

    # Required; validated as an e-mail address by ``EmailStr``.
    email: EmailStr
|
|
|
|
|
class UserCreate(UserBase):
    """Sign-up payload: the base user fields plus a plain-text password."""

    # Required; at least 6 characters long.
    password: str = Field(..., min_length=6)
|
|
|
|
|
class UserResponse(UserBase):
    """User schema returned to clients: base fields, id and account flags."""

    id: int
    is_active: bool
    is_admin: bool

    class Config:
        # Lets the schema be populated from object attributes (used here
        # with the ORM instances returned by the CRUD helpers).
        from_attributes = True
|
|
|
|
|
class UserUpdate(BaseModel):
    """Partial user update; every field is optional and defaults to None."""

    username: Optional[str] = None
    email: Optional[EmailStr] = None
    is_active: Optional[bool] = None
    is_admin: Optional[bool] = None
|
|
|
|
|
class FeedbackBase(BaseModel):
    """Fields shared by the feedback request and response schemas."""

    # Name the feedback is attributed to.
    username: str

    # Feedback text.
    comment: str
|
|
|
|
|
class FeedbackCreate(FeedbackBase):
    """Payload for submitting feedback; adds nothing to ``FeedbackBase``."""
|
|
|
|
|
class FeedbackResponse(FeedbackBase):
    """Feedback schema returned to clients: base fields, id and timestamp."""

    id: int
    created_at: datetime.datetime

    class Config:
        # Lets the schema be populated from object attributes (used here
        # with the ORM instance returned by create_feedback()).
        from_attributes = True
|
|
|
|
|
class Token(BaseModel):
    """Response body of the ``/token`` endpoint."""

    # Value the client sends back as its bearer token.
    access_token: str

    # Token scheme name; the login route returns "bearer".
    token_type: str
|
|
|
|
|
class TokenData(BaseModel):
    """Container for an optional username; defaults to None."""

    username: Optional[str] = None
|
|
|
|
|
class OCRResponse(BaseModel):
    """Response body of the ``/process/`` endpoint."""

    # Text captured from py_text_scan's standard output.
    sakshi_output: str

    # Number of bounding boxes kept by detect_words().
    word_count: int

    # Predicted label, or a status/error message when no prediction was made.
    prediction_label: str
|
|
|
|
|
|
|
|
# Password hashing context; bcrypt is the only configured scheme.
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)

# Bearer-token dependency; clients obtain tokens from the "token" URL.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="token",
)
|
|
|
|
|
def get_db():
    """Yield a new database session and close it when the consumer is done.

    Written as a generator so it can be used as a FastAPI dependency; the
    ``finally`` clause closes the session even if the consumer raises.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
|
|
|
|
|
async def get_current_user(db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)):
    """Resolve the bearer token to a user row.

    Raises:
        HTTPException: 401 when no user matches the token.
    """
    # NOTE(review): the token is looked up verbatim as a username (and the
    # /token route hands out the username as the token), so knowing a
    # username is enough to authenticate as that user. Consider signed,
    # expiring tokens.
    matched_user = get_user_by_username(db, username=token)
    if matched_user:
        return matched_user

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
|
|
|
|
async def get_current_active_user(current_user: UserModel = Depends(get_current_user)):
    """Return the authenticated user, rejecting deactivated accounts.

    Raises:
        HTTPException: 400 when ``current_user.is_active`` is falsy.
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
|
|
|
|
|
async def get_current_admin_user(current_user: UserModel = Depends(get_current_active_user)):
    """Return the active user, rejecting accounts without the admin flag.

    Raises:
        HTTPException: 403 when ``current_user.is_admin`` is falsy.
    """
    if current_user.is_admin:
        return current_user
    raise HTTPException(status_code=403, detail="Not an administrator")
|
|
|
|
|
|
|
|
def get_user(db: Session, user_id: int):
    """Return the first user whose id equals ``user_id`` (``.first()`` result)."""
    by_id = db.query(UserModel).filter(UserModel.id == user_id)
    return by_id.first()
|
|
|
|
|
def get_user_by_username(db: Session, username: str):
    """Return the first user whose username equals ``username`` (``.first()`` result)."""
    by_name = db.query(UserModel).filter(UserModel.username == username)
    return by_name.first()
|
|
|
|
|
def get_user_by_email(db: Session, email: str):
    """Return the first user whose e-mail equals ``email`` (``.first()`` result)."""
    by_email = db.query(UserModel).filter(UserModel.email == email)
    return by_email.first()
|
|
|
|
|
def get_users(db: Session, skip: int = 0, limit: int = 100):
    """Return a page of users: skip ``skip`` rows, then take up to ``limit``."""
    page = db.query(UserModel).offset(skip).limit(limit)
    return page.all()
|
|
|
|
|
def create_user(db: Session, user: UserCreate):
    """Insert a new user and return the refreshed ORM instance.

    The plain-text password is hashed with ``pwd_context`` before storage;
    only the hash is written to the row.
    """
    new_user = UserModel(
        username=user.username,
        email=user.email,
        hashed_password=pwd_context.hash(user.password),
    )
    db.add(new_user)
    db.commit()
    # Reload the row after the commit before handing it back.
    db.refresh(new_user)
    return new_user
|
|
|
|
|
def update_user(db: Session, user_id: int, user: UserUpdate):
    """Apply the explicitly provided fields of ``user`` to an existing row.

    Args:
        db: Active database session.
        user_id: Primary key of the user to modify.
        user: Partial update; only fields the caller actually set are applied.

    Returns:
        The refreshed ORM instance, or ``None`` when no user has ``user_id``.
    """
    db_user = get_user(db, user_id)
    if db_user is None:
        return None

    # ``model_dump`` is the Pydantic v2 replacement for the deprecated
    # ``.dict()``; the schemas in this file already use v2 configuration
    # (``from_attributes``).
    changes = user.model_dump(exclude_unset=True)
    for field_name, new_value in changes.items():
        setattr(db_user, field_name, new_value)

    db.commit()
    db.refresh(db_user)
    return db_user
|
|
|
|
|
def delete_user(db: Session, user_id: int):
    """Delete the user with ``user_id``.

    Returns:
        ``True`` when a row was found and deleted, ``False`` otherwise.
    """
    target = get_user(db, user_id)
    if not target:
        return False

    db.delete(target)
    db.commit()
    return True
|
|
|
|
|
def verify_password(plain_password, hashed_password):
    """Return the result of checking ``plain_password`` against ``hashed_password``."""
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
|
|
|
|
|
def create_feedback(db: Session, feedback: FeedbackCreate):
    """Insert a feedback row and return the refreshed ORM instance.

    Args:
        db: Active database session.
        feedback: Validated feedback payload.

    Returns:
        The stored ``FeedbackModel`` after the commit and refresh.
    """
    # ``model_dump`` is the Pydantic v2 replacement for the deprecated
    # ``.dict()``; the schemas in this file already use v2 configuration
    # (``from_attributes``).
    db_feedback = FeedbackModel(**feedback.model_dump())
    db.add(db_feedback)
    db.commit()
    db.refresh(db_feedback)
    return db_feedback
|
|
|
|
|
def get_feedback(db: Session, skip: int = 0, limit: int = 100):
    """Return a page of feedback rows ordered by ``created_at`` descending."""
    newest_first = db.query(FeedbackModel).order_by(FeedbackModel.created_at.desc())
    page = newest_first.offset(skip).limit(limit)
    return page.all()
|
|
|
|
|
|
|
|
# FastAPI application object; the routes below are registered on it.
app = FastAPI(
    title="Hindi OCR API",
    description="API for Hindi OCR, word detection, authentication, and feedback",
    version="1.0.0",
)

# Remote locations of the model, label encoder and font files.
MODEL_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/hindi_ocr_model.keras"
ENCODER_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/label_encoder.pkl"
FONT_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/NotoSansDevanagari-Regular.ttf"

# Local paths (relative to the working directory) the files are saved to.
MODEL_PATH = "hindi_ocr_model.keras"
ENCODER_PATH = "label_encoder.pkl"
FONT_PATH = "NotoSansDevanagari-Regular.ttf"
|
|
|
|
|
def download_file(url, dest):
    """Download ``url`` to ``dest`` unless ``dest`` already exists.

    The payload is streamed into ``<dest>.part`` and moved into place only
    after the whole body has been written, so an interrupted download never
    leaves a truncated ``dest`` that a later call would mistake for a
    complete file.

    Args:
        url: HTTP(S) location of the file.
        dest: Local path to write to.

    Raises:
        requests.RequestException: On connection errors, timeouts or a
            non-success HTTP status.
        OSError: If the local file cannot be written or moved.
    """
    if os.path.exists(dest):
        return

    print(f"Downloading {dest}...")
    partial_path = f"{dest}.part"
    try:
        # (connect, read) timeouts: the read timeout applies per socket read,
        # not to the whole transfer, so large files still download while a
        # stalled connection no longer hangs forever.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(partial_path, dest)
    finally:
        # After a successful move the partial file no longer exists; on
        # failure this removes the incomplete download.
        if os.path.exists(partial_path):
            os.unlink(partial_path)
    print(f"Downloaded {dest}")
|
|
|
|
|
def load_model():
    """Load the Keras model from ``MODEL_PATH``; return None if the file is missing."""
    if os.path.exists(MODEL_PATH):
        return tf.keras.models.load_model(MODEL_PATH)
    return None
|
|
|
|
|
def load_label_encoder():
    """Unpickle the label encoder from ``ENCODER_PATH``; return None if the file is missing."""
    if os.path.exists(ENCODER_PATH):
        # NOTE(review): unpickling executes code embedded in the file, and
        # this file is fetched from a remote URL at startup. Only safe while
        # that source is trusted.
        with open(ENCODER_PATH, 'rb') as encoder_file:
            return pickle.load(encoder_file)
    return None
|
|
|
|
|
# Populated by startup_event(); both stay None until then (or when the
# corresponding file is missing).
model = None
label_encoder = None

# Maps an artifact name ('word_detection', 'prediction') to the temp-file
# path produced by the most recent /process/ call. Module-level, so it is
# shared by every request handled by this process.
session_files = dict()
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Fetch model assets, load them, and make sure an admin account exists.

    Steps, in order:
      1. Download the model, label encoder and font if not present locally.
      2. Register the font with matplotlib when the file is available.
      3. Load the model and label encoder into the module-level globals.
      4. Create the ``admin`` user and flag it as administrator if it does
         not exist yet.
    """
    global model, label_encoder

    for url, path in (
        (MODEL_URL, MODEL_PATH),
        (ENCODER_URL, ENCODER_PATH),
        (FONT_URL, FONT_PATH),
    ):
        download_file(url, path)

    if os.path.exists(FONT_PATH):
        fm.fontManager.addfont(FONT_PATH)
        plt.rcParams['font.family'] = 'Noto Sans Devanagari'

    model = load_model()
    label_encoder = load_label_encoder()

    db = SessionLocal()
    try:
        if not get_user_by_username(db, "admin"):
            # NOTE(review): default credentials are hard-coded; change the
            # password after first start or source it from configuration.
            admin_user = UserCreate(username="admin", email="admin@example.com", password="adminpassword")
            create_user(db, admin_user)
            admin = get_user_by_username(db, "admin")
            admin.is_admin = True
            db.commit()
    finally:
        # Close the session even if seeding the admin account fails.
        db.close()
|
|
|
|
|
def detect_words(image):
    """Draw boxes around candidate word regions and count them.

    The image is binarised (inverted Otsu threshold), dilated twice with a
    3x3 kernel, and the external contours of the result are boxed. Only
    boxes wider and taller than 10 pixels are drawn and counted.

    Args:
        image: Single-channel image (it is passed to ``cv2.threshold`` and
            converted with ``COLOR_GRAY2BGR``).

    Returns:
        Tuple ``(annotated_bgr_image, word_count)``.
    """
    _, binary = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    dilated = cv2.dilate(binary, np.ones((3, 3), np.uint8), iterations=2)
    contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    annotated = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

    # Keep only boxes large enough in both dimensions.
    boxes = [cv2.boundingRect(contour) for contour in contours]
    kept = [(x, y, w, h) for (x, y, w, h) in boxes if w > 10 and h > 10]

    for x, y, w, h in kept:
        cv2.rectangle(annotated, (x, y), (x + w, y + h), (0, 255, 0), 2)

    return annotated, len(kept)
|
|
|
|
|
def run_py_text_scan(image_path):
    """Run ``py_text_scan.generate`` on ``image_path`` and return what it printed.

    Args:
        image_path: Path of the image file handed to ``py_text_scan.generate``.

    Returns:
        Everything written to ``sys.stdout`` during the call, as one string.
    """
    import contextlib

    buffer = io.StringIO()
    # redirect_stdout swaps sys.stdout for the duration of the block and
    # restores the previous stream afterwards, even if generate() raises.
    with contextlib.redirect_stdout(buffer):
        py_text_scan.generate(image_path)
    return buffer.getvalue()
|
|
|
|
|
def _to_grayscale(image_array):
    """Return a single-channel view/copy of ``image_array`` for any channel count."""
    if image_array.ndim == 2:
        # Already single-channel.
        return image_array
    channels = image_array.shape[2]
    if channels == 3:
        return cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
    if channels == 4:
        return cv2.cvtColor(image_array, cv2.COLOR_RGBA2GRAY)
    # Any other layout (e.g. luminance + alpha): use the first channel.
    return image_array[:, :, 0]


def _reserve_png_path():
    """Create an empty temporary ``.png`` file, close it, and return its path."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as handle:
        return handle.name


def process_image(image_array):
    """Run word detection, label prediction and text scanning on an image.

    Side effects: writes the annotated word-detection image (and, when a
    prediction is made, the prediction figure) to temporary PNG files and
    records their paths in the module-level ``session_files`` dict.

    Args:
        image_array: Image as a NumPy array; 2-D arrays are used as-is,
            3- and 4-channel arrays are converted to grayscale.

    Returns:
        dict with keys ``sakshi_output``, ``word_detection_path``,
        ``word_count``, ``prediction_path`` (None when no prediction image
        was produced) and ``prediction_label``.
    """
    # NOTE(review): session_files is shared by all requests in the process,
    # so concurrent users overwrite each other's artifacts.
    img = _to_grayscale(image_array)

    word_detected_img, word_count = detect_words(img)
    word_detection_path = _reserve_png_path()
    cv2.imwrite(word_detection_path, word_detected_img)
    session_files['word_detection'] = word_detection_path

    pred_path = None
    try:
        img_resized = cv2.resize(img, (128, 32))
        img_norm = img_resized / 255.0
        # Add leading batch and trailing channel axes.
        img_input = img_norm[np.newaxis, ..., np.newaxis]
        if model is not None and label_encoder is not None:
            pred = model.predict(img_input)
            pred_label_idx = np.argmax(pred)
            pred_label = label_encoder.inverse_transform([pred_label_idx])[0]

            fig, ax = plt.subplots()
            try:
                ax.imshow(img, cmap='gray')
                ax.set_title(f"Predicted: {pred_label}", fontsize=12)
                ax.axis('off')
                figure_path = _reserve_png_path()
                fig.savefig(figure_path)
            finally:
                # Close this specific figure even when rendering fails.
                plt.close(fig)
            pred_path = figure_path
            session_files['prediction'] = pred_path
        else:
            pred_label = "Model or encoder not loaded"
    except Exception as e:
        # Best effort: a failed prediction is reported through the label
        # instead of aborting the whole request.
        pred_label = f"Error: {str(e)}"

    scan_path = _reserve_png_path()
    try:
        cv2.imwrite(scan_path, img)
        sakshi_output = run_py_text_scan(scan_path)
    finally:
        # Remove the scratch image even if the scan raises.
        os.unlink(scan_path)

    return {
        "sakshi_output": sakshi_output,
        "word_detection_path": word_detection_path,
        "word_count": word_count,
        "prediction_path": pred_path,
        "prediction_label": pred_label,
    }
|
|
|
|
|
@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Check username/password form credentials and issue a bearer token.

    Raises:
        HTTPException: 401 when the user is unknown or the password is wrong.
    """
    user = get_user_by_username(db, form_data.username)
    credentials_ok = bool(user) and verify_password(form_data.password, user.hashed_password)
    if not credentials_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # NOTE(review): the issued token is simply the username.
    return {"access_token": user.username, "token_type": "bearer"}
|
|
|
|
|
@app.post("/signup", response_model=UserResponse)
async def signup(user: UserCreate, db: Session = Depends(get_db)):
    """Register a new user after checking username and e-mail are unused.

    Raises:
        HTTPException: 400 when the username or the e-mail is already taken.
    """
    if get_user_by_username(db, username=user.username):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already registered")

    if get_user_by_email(db, email=user.email):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered")

    return create_user(db=db, user=user)
|
|
|
|
|
@app.post("/process/", response_model=OCRResponse)
async def process(file: UploadFile = File(...), current_user: UserModel = Depends(get_current_active_user)):
    """Run the OCR pipeline on an uploaded image.

    Artifacts of the previous run recorded in ``session_files`` are removed
    first; the upload is then spooled to a temporary file, decoded, and
    passed to ``process_image``.

    Raises:
        HTTPException: 400 when the upload is not declared as an image,
            500 when decoding or processing fails.
    """
    # content_type can be missing on an upload; treat that as "not an image"
    # instead of failing on None.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    # Best-effort cleanup of the previous run's files.
    for filepath in list(session_files.values()):
        try:
            os.unlink(filepath)
        except OSError:
            # Already gone or not removable; nothing useful to do here.
            pass
    session_files.clear()

    temp_file = tempfile.NamedTemporaryFile(delete=False)
    try:
        with temp_file as f:
            shutil.copyfileobj(file.file, f)

        # Close the image file once decoded, and normalise to 3-channel RGB
        # so grayscale, palette and RGBA uploads reach process_image in the
        # layout its RGB->gray conversion expects.
        with Image.open(temp_file.name) as image:
            image_array = np.array(image.convert("RGB"))

        result = process_image(image_array)
        return OCRResponse(
            sakshi_output=result["sakshi_output"],
            word_count=result["word_count"],
            prediction_label=result["prediction_label"],
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
    finally:
        os.unlink(temp_file.name)
|
|
|
|
|
@app.get("/word-detection/")
async def get_word_detection(current_user: UserModel = Depends(get_current_active_user)):
    """Serve the word-detection image recorded by the latest ``/process/`` run.

    Raises:
        HTTPException: 404 when no image is recorded or the file is gone.
    """
    if 'word_detection' in session_files and os.path.exists(session_files['word_detection']):
        return FileResponse(session_files['word_detection'])
    raise HTTPException(status_code=404, detail="Word detection image not found")
|
|
|
|
|
@app.get("/prediction/")
async def get_prediction(current_user: UserModel = Depends(get_current_active_user)):
    """Serve the prediction image recorded by the latest ``/process/`` run.

    Raises:
        HTTPException: 404 when no image is recorded or the file is gone.
    """
    if 'prediction' in session_files and os.path.exists(session_files['prediction']):
        return FileResponse(session_files['prediction'])
    raise HTTPException(status_code=404, detail="Prediction image not found")
|
|
|
|
|
|
|
|
|
|
|
@app.post("/feedback/", response_model=FeedbackResponse)
async def create_feedback_route(feedback: FeedbackCreate, db: Session = Depends(get_db)):
    """Store the submitted feedback and return the saved record."""
    # NOTE(review): this route has no authentication dependency; the
    # username comes straight from the request body.
    saved = create_feedback(db=db, feedback=feedback)
    return saved
|
|
|
|
|
|
|
|
@app.get("/admin/users/", response_model=List[UserResponse])
async def admin_get_users(skip: int = 0, limit: int = 100, current_user: UserModel = Depends(get_current_admin_user), db: Session = Depends(get_db)):
    """Return a page of users (administrators only).

    The response is filtered through ``UserResponse`` so that only the
    public fields are serialised; without a response model the ORM rows
    would be returned with their ``hashed_password`` column included.

    Args:
        skip: Number of rows to skip.
        limit: Maximum number of rows to return.
    """
    return get_users(db, skip=skip, limit=limit)
|
|
|
|
|
@app.delete("/admin/users/{user_id}")
async def admin_delete_user(user_id: int, current_user: UserModel = Depends(get_current_admin_user), db: Session = Depends(get_db)):
    """Delete the user with ``user_id`` (administrators only).

    Raises:
        HTTPException: 404 when no such user exists.
    """
    deleted = delete_user(db, user_id)
    if not deleted:
        raise HTTPException(status_code=404, detail="User not found")
    return {"detail": "User deleted successfully"}
|
|
|
|
|
@app.get("/admin/feedback/")
async def admin_get_feedback(skip: int = 0, limit: int = 100, current_user: UserModel = Depends(get_current_admin_user), db: Session = Depends(get_db)):
    """Return a page of feedback entries, newest first (administrators only)."""
    entries = get_feedback(db, skip=skip, limit=limit)
    return entries
|
|
|
|
|
# When executed as a script, serve the app on all interfaces, port 8000.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
|
|
|