import os
import re
import threading
from contextlib import asynccontextmanager

import onnxruntime as ort
import torch
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

from inference_onnx import get_transcription
from config import *
|
|
| |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer and the ONNX session once, before serving requests.

    Attaches ``device``, ``tokenizer``, ``token_style`` and ``session`` to
    ``app.state`` for the /punctuate handler. Everything before ``yield`` runs
    at startup; the line after it runs at shutdown.
    """
    print("🔧 Loading model...")

    # Local directory holding the tokenizer files; also the key into MODELS.
    model_dir = "./distilbert-base-multilingual-cased"

    app.state.device = torch.device('cpu')
    # NOTE(review): MODELS comes from `from config import *`. Entry [1]
    # exposes from_pretrained (presumably a tokenizer class) and entry [3] is
    # the token style later passed to get_transcription — confirm in config.py.
    model_spec = MODELS[model_dir]
    app.state.tokenizer = model_spec[1].from_pretrained(model_dir)
    app.state.token_style = model_spec[3]

    onnx_model_path = "./poc_onnx_model_punctuation_batch.onnx"
    providers = ['CPUExecutionProvider']

    # Pass the options to the session so that settings configured on this
    # object (thread counts, graph optimisation level) take effect; an
    # unmodified SessionOptions matches the defaults used when none is given.
    sess_options = ort.SessionOptions()
    app.state.session = ort.InferenceSession(
        onnx_model_path, sess_options=sess_options, providers=providers
    )

    print("✅ ONNX model loaded into memory.")
    yield
    print("🧹 Shutting down...")
|
|
# The ASGI application; `lifespan` loads the model resources into
# `app.state` at startup.
app = FastAPI(
    lifespan=lifespan,
)
|
|
# Punctuation characters recognised by the text cleaner, mapped to label
# names (presumably the model's output labels — confirm). '।' is U+0964,
# the danda used as a full stop in Bangla.
punc_dict = dict(
    [
        ('!', 'EXCLAMATION'),
        ('?', 'QUESTION'),
        (',', 'COMMA'),
        (';', 'SEMICOLON'),
        (':', 'COLON'),
        ('-', 'HYPHEN'),
        ('।', 'DARI'),
    ]
)

# Just the punctuation characters, for O(1) membership tests.
allowed_punctuations = set(punc_dict)
|
|
def clean_and_normalize_text(text, remove_punctuations=False, *, punctuations=None):
    """Clean and normalize Bangla text with correct spacing.

    Args:
        text: The string to clean.
        remove_punctuations: When True, delete every punctuation character and
            collapse whitespace runs to single spaces; nothing else is
            filtered in this mode. When False (default), keep the punctuation
            characters, drop every other character that is neither in the
            Bengali Unicode block (U+0980-U+09FF) nor whitespace, attach each
            punctuation mark directly to the preceding text, and put a single
            space before the text that follows it.
        punctuations: Optional iterable of single characters to treat as
            punctuation. Defaults to the module-level ``allowed_punctuations``.
            An empty collection disables punctuation handling.

    Returns:
        The cleaned string, without leading or trailing whitespace.
    """
    puncts = allowed_punctuations if punctuations is None else set(punctuations)

    # One character class matching any punctuation mark. re.escape keeps '-'
    # and '?' literal inside the brackets, so set iteration order is harmless.
    # "(?!)" never matches: with no marks there is nothing to split or remove
    # (an empty "[]" would be an invalid pattern).
    punct_class = f"[{re.escape(''.join(puncts))}]" if puncts else "(?!)"
    whitespace = re.compile(r"\s+")

    if remove_punctuations:
        cleaned_text = re.sub(punct_class, "", text)
        return whitespace.sub(" ", cleaned_text).strip()

    # U+0980-U+09FF is the whole Bengali block, Bengali digits included.
    # Compiled once here instead of once per chunk inside the loop.
    # NOTE(review): this also strips ZWNJ/ZWJ (U+200C/U+200D), which Bengali
    # text sometimes uses to select conjunct forms — confirm that is intended.
    non_bangla = re.compile(r"[^\u0980-\u09FF\s]")

    # The capturing group makes re.split keep each mark as its own chunk.
    chunks = re.split(f"({punct_class})", text)

    pieces = []
    for chunk in chunks:
        if chunk in puncts:
            # No space before a mark: it attaches to the preceding word.
            pieces.append(chunk)
            continue
        clean_chunk = whitespace.sub(" ", non_bangla.sub("", chunk)).strip()
        if clean_chunk:
            # Leading space separates this text from whatever came before it.
            pieces.append(" " + clean_chunk)

    return whitespace.sub(" ", "".join(pieces)).strip()
|
|
# Request body for POST /punctuate.
class TextInput(BaseModel):
    # Raw input text; the endpoint cleans it and strips punctuation before
    # running the model.
    text: str
|
|
# Inference runs in a worker thread (see punctuate_text). This lock keeps
# model calls one at a time, as they were when run directly on the event
# loop. NOTE(review): ONNX Runtime allows concurrent run() on one session,
# but whether the shared tokenizer tolerates concurrent use depends on
# get_transcription — confirm before removing the lock.
_INFERENCE_LOCK = threading.Lock()


def _restore_punctuation(text):
    """Run the blocking punctuation model on *text*, one call at a time."""
    with _INFERENCE_LOCK:
        return get_transcription(
            text,
            app.state.session,
            app.state.tokenizer,
            app.state.device,
            app.state.token_style,
        )


@app.post("/punctuate")
async def punctuate_text(data: TextInput):
    """Restore punctuation in Bangla text.

    The input is filtered to Bengali-script characters, stripped of
    punctuation and passed to the punctuation model. Returns
    ``{"restored_text": ...}``; the value is an empty string when nothing is
    left after cleaning.
    """
    input_normalized = clean_and_normalize_text(data.text)
    input_normalized = clean_and_normalize_text(input_normalized, remove_punctuations=True)

    # Nothing to punctuate: skip the model call entirely.
    if not input_normalized:
        return {"restored_text": ""}

    # get_transcription is blocking CPU work; running it in a worker thread
    # keeps the event loop free to accept and serve other requests meanwhile.
    restored_text = await run_in_threadpool(_restore_punctuation, input_normalized)
    return {"restored_text": restored_text}
|
|
if __name__ == "__main__":
    import uvicorn

    # Hand uvicorn the app object rather than the "api_onnx:app" import
    # string. An instance is accepted when reload is off and workers == 1. It
    # avoids importing this file a second time under another module name,
    # and it keeps working if the file is not named api_onnx.py.
    uvicorn.run(app, host="0.0.0.0", port=5685, workers=1)
|
|