|
|
from fastapi import FastAPI, File, UploadFile, Form, HTTPException |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.responses import JSONResponse |
|
|
from typing import Optional, Any, Dict, Union |
|
|
import shutil |
|
|
import os |
|
|
import json |
|
|
from loguru import logger |
|
|
from pathlib import Path |
|
|
import tempfile |
|
|
import numpy as np |
|
|
from datetime import datetime |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
from PIL import Image |
|
|
from main import RenAITranscription |
|
|
|
|
|
# ASGI application object; the route decorators in this module register
# their handlers on it.
app = FastAPI(
    title="RenAI Transcription API",
    version="1.0.0",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Upload file extensions (lower-case, with leading dot) the API accepts.
ALLOWED_EXTENSIONS = {
    ".jpg",
    ".jpeg",
    ".png",
    ".bmp",
    ".tiff",
    ".webp",
}

# Largest accepted upload, in bytes (10 MiB).
MAX_FILE_SIZE = 10 * 1024 ** 2
|
|
|
|
|
def numpy_to_base64(array: np.ndarray, format: str = 'PNG', quality: int = 85) -> Optional[str]:
    """
    Convert a numpy array (image) to a base64 data URI for web display.

    Args:
        array: Numpy array representing the image; it is handed to
            ``Image.fromarray`` unchanged.
        format: Output encoding, case-insensitive. 'JPEG' (or the alias
            'JPG') produces a JPEG; any other value produces a PNG.
        quality: JPEG quality (1-100), only used when the output is JPEG.

    Returns:
        Data URI string that can be used directly in an HTML <img> src
        attribute, or None if the conversion fails (the error is logged
        instead of being raised).
    """
    try:
        img = Image.fromarray(array)
        buffer = BytesIO()

        if format.upper() in ('JPEG', 'JPG'):
            # JPEG cannot store transparency or a palette, so such images
            # are flattened onto a white background first.
            if img.mode in ('RGBA', 'LA', 'P'):
                if img.mode == 'P':
                    img = img.convert('RGBA')
                background = Image.new('RGB', img.size, (255, 255, 255))
                # img is RGBA or LA at this point, so its last band is
                # always the alpha channel.
                background.paste(img, mask=img.split()[-1])
                img = background
            img.save(buffer, format='JPEG', quality=quality, optimize=True)
            mime_type = 'image/jpeg'
        else:
            img.save(buffer, format='PNG', optimize=True)
            mime_type = 'image/png'

        img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
        return f"data:{mime_type};base64,{img_str}"
    except Exception as e:
        # Best-effort helper: callers treat None as "no image available".
        logger.error(f"Error converting numpy array to base64: {e}")
        return None
|
|
|
|
|
def format_transcription_result(result: Dict, include_images: bool = False, image_format: str = 'PNG') -> Dict[str, Any]:
    """
    Format a transcription result into a structured response.

    Args:
        result: Dictionary with line IDs as keys; each value is a dict that
            may contain 'image' and 'transcription'. None is treated as an
            empty result.
        include_images: Whether to include base64 encoded images in the
            response. Only values that are numpy arrays are encoded.
        image_format: Image format for base64 encoding ('PNG' or 'JPEG').

    Returns:
        Dictionary with three keys:
            'lines': per-line dicts (keyed by line ID) holding 'line_id',
                'transcription' and, when encoded successfully, 'image';
            'full_text': one "<line_id>: <transcription>" row per line,
                joined with newlines;
            'total_lines': number of lines.
    """
    formatted_lines = {}
    transcription_text = []

    # A missing result is handled like "no lines found" instead of crashing.
    for line_id, line_data in (result or {}).items():
        # Read the text once; a missing or None transcription becomes ''
        # so that full_text never contains the literal string "None".
        transcription = line_data.get('transcription')
        if transcription is None:
            transcription = ''

        formatted_line = {
            'line_id': line_id,
            'transcription': transcription,
        }

        if include_images:
            image_array = line_data.get('image')
            if isinstance(image_array, np.ndarray):
                image_base64 = numpy_to_base64(image_array, format=image_format)
                # The encoder returns a falsy value on failure; in that
                # case the line is simply returned without an image.
                if image_base64:
                    formatted_line['image'] = image_base64

        formatted_lines[line_id] = formatted_line
        transcription_text.append(f"{line_id}: {transcription}")

    return {
        'lines': formatted_lines,
        'full_text': '\n'.join(transcription_text),
        'total_lines': len(formatted_lines),
    }
|
|
|
|
|
@app.get("/")
def home():
    # Landing route: a greeting, the API version and an index of the routes.
    endpoint_index = {
        "transcribe": "/renai-transcribe (POST)",
        "transcribe_base64": "/renai-transcribe-base64 (POST)",
        "health": "/health (GET)",
    }
    landing_payload = {
        "message": "Hello, RenAI!",
        "version": "1.0.0",
        "endpoints": endpoint_index,
    }
    return landing_payload
|
|
|
|
|
@app.post("/renai-transcribe")
async def transcription_endpoint(
    image: UploadFile = File(..., description="Image file to transcribe"),
    userToken: Optional[str] = Form(None, description="User authentication token"),
    post_processing_enabled: bool = Form(False, description="Enable post-processing"),
    unet_enabled: bool = Form(False, description="Enable UNet processing"),
    include_images: bool = Form(True, description="Include base64 encoded line images in response"),
    image_format: str = Form("JPEG", description="Image format for line images: PNG or JPEG")
):
    """
    Upload an image file and get transcription results.

    - **image**: Image file (JPG, PNG, BMP, TIFF, WebP)
    - **userToken**: Optional user authentication token
    - **post_processing_enabled**: Enable/disable post-processing
    - **unet_enabled**: Enable/disable UNet processing
    - **include_images**: Include base64 encoded images of each line (web-ready format)
    - **image_format**: Format for line images: 'PNG' (higher quality, larger) or 'JPEG' (smaller, faster)
    """
    # Flow: validate the upload -> spool it to a temp file -> run
    # RenAITranscription on that path -> format the result -> JSON response.
    # Validation problems are reported as HTTP 400, any other failure as
    # HTTP 500 carrying the error text and exception type name.
    start_time = datetime.now()
    # The token is a credential, so only its presence is logged.
    logger.info(f"Transcription request received for file: {image.filename} by userToken: {'<redacted>' if userToken else 'Anonymous'}")

    if not image.filename:
        raise HTTPException(status_code=400, detail="No file provided")

    file_extension = Path(image.filename).suffix.lower()
    if file_extension not in ALLOWED_EXTENSIONS:
        # sorted() keeps the message stable; set iteration order is not.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type. Allowed types: {', '.join(sorted(ALLOWED_EXTENSIONS))}"
        )

    too_large_detail = f"File too large. Maximum size: {MAX_FILE_SIZE // (1024*1024)}MB"

    # Early rejection when the upload object already reports its size.
    declared_size = getattr(image, "size", None)
    if declared_size and declared_size > MAX_FILE_SIZE:
        raise HTTPException(status_code=400, detail=too_large_detail)

    temp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
            # Record the path before copying so the finally-block can
            # remove the file even when the copy itself fails.
            temp_file_path = temp_file.name
            shutil.copyfileobj(image.file, temp_file)
            written_size = temp_file.tell()

        # Enforce the limit again on the bytes actually received, which
        # also covers uploads whose size was not reported up front.
        if written_size > MAX_FILE_SIZE:
            raise HTTPException(status_code=400, detail=too_large_detail)

        logger.info(f"Processing image: {temp_file_path}")

        # NOTE(review): this call is synchronous inside an async handler,
        # so the event loop is blocked while it runs. Consider moving it to
        # a worker thread once RenAITranscription's thread-safety is
        # confirmed.
        result = RenAITranscription(
            image=temp_file_path,
            post_processing_enabled=post_processing_enabled,
            unet_enabled=unet_enabled
        )

        logger.info(f"Transcription completed. Result type: {type(result)}, Lines: {len(result)}")

        formatted_result = format_transcription_result(result, include_images=include_images, image_format=image_format)

        processing_time = (datetime.now() - start_time).total_seconds()
        logger.info(f"Request completed in {processing_time:.2f}s")

        response_data = {
            "success": True,
            "filename": image.filename,
            "transcription": formatted_result,
            "metadata": {
                "processing_time_seconds": round(processing_time, 2),
                "timestamp": datetime.now().isoformat(),
                "total_lines": formatted_result['total_lines'],
                "parameters": {
                    "post_processing_enabled": post_processing_enabled,
                    "unet_enabled": unet_enabled,
                    "include_images": include_images,
                    # Reported the same way as the base64 endpoint does.
                    "image_format": image_format if include_images else None,
                    "userToken": userToken if userToken else "Anonymous"
                }
            }
        }

        return JSONResponse(content=response_data)

    except HTTPException:
        # Deliberate client-facing errors must not be rewrapped as 500s.
        raise
    except Exception as e:
        logger.error(f"Transcription failed: {e}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": str(e),
                "type": type(e).__name__
            }
        ) from e
    finally:
        # Single cleanup point for both the success and the failure path.
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except FileNotFoundError:
                pass
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary file {temp_file_path}: {cleanup_error}")
|
|
|
|
|
@app.post("/renai-transcribe-base64")
async def transcription_base64_endpoint(
    image_data: str = Form(..., description="Base64 encoded image data"),
    userToken: Optional[str] = Form(None, description="User authentication token"),
    post_processing_enabled: bool = Form(False, description="Enable post-processing"),
    unet_enabled: bool = Form(False, description="Enable UNet processing"),
    include_images: bool = Form(False, description="Include base64 encoded line images in response"),
    image_format: str = Form("JPEG", description="Image format for line images: PNG or JPEG")
):
    """
    Alternative endpoint that accepts base64 encoded image data.
    """
    # Flow: decode the payload -> check size and that it is an image ->
    # write it to a temporary PNG -> run RenAITranscription on that path ->
    # format the result -> JSON response. Undecodable, oversized or
    # non-image payloads are reported as HTTP 400; any other failure as
    # HTTP 500 carrying the error text and exception type name.
    start_time = datetime.now()
    # The token is a credential, so only its presence is logged.
    logger.info(f"Base64 transcription request received by userToken: {'<redacted>' if userToken else 'Anonymous'}")

    temp_file_path = None
    try:
        # Accept full data URIs by dropping everything up to the first comma.
        if "," in image_data:
            image_data = image_data.split(",", 1)[1]

        try:
            image_bytes = base64.b64decode(image_data)
        except ValueError as decode_error:
            # binascii.Error (bad padding/characters) is a ValueError.
            raise HTTPException(status_code=400, detail="Invalid base64 image data") from decode_error

        # Same size ceiling as the file-upload endpoint.
        if len(image_bytes) > MAX_FILE_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"File too large. Maximum size: {MAX_FILE_SIZE // (1024*1024)}MB"
            )

        try:
            image_pil = Image.open(BytesIO(image_bytes))
            # Image.open is lazy; force decoding here so corrupt data is
            # reported as a client error rather than failing during save.
            image_pil.load()
        except (OSError, ValueError) as image_error:
            raise HTTPException(status_code=400, detail="Decoded data is not a valid image") from image_error

        # PNG cannot store every Pillow mode (e.g. CMYK); fall back to RGB
        # for anything outside the modes the PNG writer accepts.
        if image_pil.mode not in ("1", "L", "LA", "I", "I;16", "I;16B", "P", "RGB", "RGBA"):
            image_pil = image_pil.convert("RGB")

        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            # Record the path before saving so the finally-block can remove
            # the file even when the save itself fails.
            temp_file_path = temp_file.name
            image_pil.save(temp_file_path)

        logger.info(f"Processing base64 image: {temp_file_path}")

        # NOTE(review): this call is synchronous inside an async handler,
        # so the event loop is blocked while it runs. Consider moving it to
        # a worker thread once RenAITranscription's thread-safety is
        # confirmed.
        result = RenAITranscription(
            image=temp_file_path,
            post_processing_enabled=post_processing_enabled,
            unet_enabled=unet_enabled
        )

        formatted_result = format_transcription_result(result, include_images=include_images, image_format=image_format)

        processing_time = (datetime.now() - start_time).total_seconds()
        logger.info(f"Base64 request completed in {processing_time:.2f}s")

        response_data = {
            "success": True,
            "transcription": formatted_result,
            "metadata": {
                "processing_time_seconds": round(processing_time, 2),
                "timestamp": datetime.now().isoformat(),
                "total_lines": formatted_result['total_lines'],
                "parameters": {
                    "post_processing_enabled": post_processing_enabled,
                    "unet_enabled": unet_enabled,
                    "include_images": include_images,
                    "image_format": image_format if include_images else None,
                    "userToken": userToken if userToken else "Anonymous"
                }
            }
        }

        return JSONResponse(content=response_data)

    except HTTPException:
        # Deliberate client-facing errors must not be rewrapped as 500s.
        raise
    except Exception as e:
        logger.error(f"Base64 transcription failed: {e}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": str(e),
                "type": type(e).__name__
            }
        ) from e
    finally:
        # Single cleanup point for both the success and the failure path.
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except FileNotFoundError:
                pass
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary file {temp_file_path}: {cleanup_error}")
|
|
|
|
|
@app.get("/health")
def health_check():
    # Liveness probe: reports a fixed "healthy" status plus the current
    # server time; any unexpected error is logged and surfaced as HTTP 500.
    try:
        health_payload = {
            "status": "healthy",
            "service": "RenAI Transcription API",
        }
        health_payload["timestamp"] = datetime.now().isoformat()
        return health_payload
    except Exception as err:
        logger.error(f"Health check failed: {err}")
        raise HTTPException(status_code=500, detail="Service unhealthy")