Spaces:
Runtime error
Runtime error
| """ | |
| FastAPI Server for Text Correction | |
| Deploy this to run your text correction model as an API | |
| """ | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| import torch | |
| import os | |
| from typing import Optional | |
# Create the FastAPI application. These values are the metadata FastAPI
# publishes in the generated OpenAPI schema and interactive docs.
app = FastAPI(
    version="1.0.0",
    title="Text Correction API",
    description="API for correcting OCR text using trained model",
)
# Add CORS middleware for browser-based callers. CORS is enforced by browsers
# only; non-browser clients (such as a native iOS app) are unaffected by it.
# Set CORS_ALLOW_ORIGINS to a comma-separated list of origins to restrict
# access in production; the default allows every origin.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
# FastAPI's CORS documentation states that credentials cannot be combined with
# a wildcard origin (Starlette would otherwise echo back *any* origin together
# with Access-Control-Allow-Credentials). Only allow credentials when the
# origins are listed explicitly.
_cors_allow_credentials = "*" not in _cors_origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=_cors_allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Module-level inference state, assigned by load_model(). Request handlers
# treat a None model/tokenizer as "not ready yet".
model = tokenizer = None
device: Optional[str] = None  # "cuda" or "cpu" once load_model() has run
# Pydantic models for request/response.
# (Documented with comments rather than docstrings: Pydantic would copy a class
# docstring into the OpenAPI schema description.)


# Request body for POST /correct.
class TextRequest(BaseModel):
    text: str  # raw (e.g. OCR) text to correct


# Response body for POST /correct.
class TextResponse(BaseModel):
    corrected_text: str
    processing_time: float  # seconds, rounded to 2 decimals by the handler


# Response body for GET /health.
class HealthResponse(BaseModel):
    status: str  # "healthy" or "unhealthy"
    model_loaded: bool
    device: str  # "cuda", "cpu", or "unknown" before the model is loaded

    # Pydantic v2 reserves the "model_" prefix and warns about fields such as
    # model_loaded; an empty tuple disables that check. model_config replaces
    # the class-based `Config`, which is deprecated in Pydantic v2.
    model_config = {"protected_namespaces": ()}
# Load model at startup
def _resolve_model_path() -> str:
    """Return the model location: $MODEL_PATH, ./gpu_base_model2, or a Hub repo id."""
    # An explicitly configured path always wins (an empty value counts as unset).
    model_path = os.getenv("MODEL_PATH")
    if model_path:
        return model_path
    # Otherwise prefer a model directory shipped next to the app.
    if os.path.exists("./gpu_base_model2"):
        return "./gpu_base_model2"
    # Fall back to downloading the model repository from the Hugging Face Hub.
    model_path = os.getenv("HF_MODEL_PATH", "MdSourav76046/TextCorrectionModel2")
    print(f"📥 Model not found locally, will download from: {model_path}")
    print(" This may take a few minutes on first run...")
    return model_path


@app.on_event("startup")
async def load_model():
    """Load the seq2seq correction model and its tokenizer into module globals.

    The model location is resolved in this order: the MODEL_PATH environment
    variable, a local ./gpu_base_model2 directory, then the Hugging Face Hub
    repository named by HF_MODEL_PATH (default
    "MdSourav76046/TextCorrectionModel2").

    Errors are printed rather than raised so the server still starts. In that
    case `model` and `tokenizer` stay None and the endpoints report that the
    model is not loaded.
    """
    global model, tokenizer, device

    print("🚀 Starting Text Correction API...")

    # Default the Hugging Face cache to /tmp when the environment doesn't
    # configure one.
    cache_configured = bool(
        os.environ.get("TRANSFORMERS_CACHE") or os.environ.get("HF_HOME")
    )
    if not os.environ.get("TRANSFORMERS_CACHE"):
        os.environ["TRANSFORMERS_CACHE"] = "/tmp"
    if not os.environ.get("HF_HOME"):
        os.environ["HF_HOME"] = "/tmp"
    # transformers / huggingface_hub resolve their default cache directory from
    # these variables when they are first imported (at the top of this module),
    # so the assignments above come too late to affect from_pretrained. Pass
    # the /tmp default explicitly instead. A cache configured before startup is
    # already honoured by the libraries, so cache_dir stays None in that case.
    cache_dir = None if cache_configured else "/tmp"

    # Determine device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"📱 Using device: {device}")

    try:
        model_path = _resolve_model_path()
        print(f"📦 Loading model from: {model_path}")
        # Load into locals and publish only after everything succeeded, so a
        # partial failure can't expose a model without a tokenizer (or a model
        # that never reached the target device) to the request handlers.
        loaded_model = AutoModelForSeq2SeqLM.from_pretrained(
            model_path, cache_dir=cache_dir
        )
        loaded_tokenizer = AutoTokenizer.from_pretrained(
            model_path, cache_dir=cache_dir
        )
        loaded_model.to(device)
        loaded_model.eval()
    except Exception as e:
        # Deliberate best-effort: keep serving and report "not loaded".
        print(f"❌ Error loading model: {e}")
        print("⚠️ API will not work until model is loaded")
        return

    model, tokenizer = loaded_model, loaded_tokenizer
    print("✅ Model loaded successfully!")
    print(f" - Model type: {type(model).__name__}")
    print(f" - Vocabulary size: {tokenizer.vocab_size}")
    print(f" - Device: {device}")
# Health check endpoint
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Check if the API and model are ready.

    The API counts as "healthy" only when both the model and the tokenizer
    are loaded. This is the same condition the correction endpoint checks
    before rejecting requests with 503.
    """
    ready = model is not None and tokenizer is not None
    return HealthResponse(
        status="healthy" if ready else "unhealthy",
        model_loaded=model is not None,
        device=device or "unknown",
    )
# Text correction endpoint
def _run_correction(text: str) -> str:
    """Tokenize *text*, generate a correction with beam search, and decode it.

    This call blocks while the model runs, so `correct_text` invokes it in a
    worker thread. It reads the module-level `model`, `tokenizer` and `device`;
    the caller must already have checked that they are loaded.
    """
    # Inputs are truncated to 512 tokens, so only the beginning of very long
    # texts is corrected.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)
    # torch.no_grad() is thread-local, so it has to be entered in the thread
    # that actually runs generation.
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=512,
            num_beams=5,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


@app.post("/correct", response_model=TextResponse)
async def correct_text(request: TextRequest):
    """
    Correct text using the trained model

    Args:
        request: TextRequest containing the text to correct

    Returns:
        TextResponse with corrected text and processing time

    Raises:
        HTTPException: 503 if the model or tokenizer is not loaded, 400 if the
            text is empty or whitespace-only, 500 if correction fails.
    """
    import asyncio
    import time

    if model is None or tokenizer is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Please wait for the model to initialize.",
        )
    if not request.text or not request.text.strip():
        raise HTTPException(
            status_code=400,
            detail="Text cannot be empty",
        )

    # perf_counter is monotonic, unlike time.time(), so durations can't go
    # negative if the wall clock is adjusted.
    start_time = time.perf_counter()
    try:
        # Generation is compute-bound and synchronous; running it inline would
        # block the event loop and stall every other request (including
        # /health) until it finished.
        loop = asyncio.get_running_loop()
        corrected_text = await loop.run_in_executor(
            None, _run_correction, request.text
        )
    except Exception as e:
        print(f"❌ Error during correction: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Text correction failed: {str(e)}",
        ) from e
    processing_time = time.perf_counter() - start_time

    print(f"✅ Text corrected in {processing_time:.2f}s")
    print(f" Input: {request.text[:50]}...")
    print(f" Output: {corrected_text[:50]}...")
    return TextResponse(
        corrected_text=corrected_text,
        processing_time=round(processing_time, 2),
    )
# Root endpoint
@app.get("/")
async def root():
    """Describe the API: its name, version, and available endpoints."""
    return {
        "message": "Text Correction API",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health",
            "correct": "/correct (POST)",
        },
    }
if __name__ == "__main__":
    import uvicorn

    # Allow the bind address and port to be overridden through HOST / PORT,
    # e.g. on a hosting platform that assigns the port. The defaults match the
    # previous hard-coded values.
    uvicorn.run(
        app,
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "8000")),
    )