MdSourav76046 committed on
Commit
5aa4532
·
verified ·
1 Parent(s): 06b80ad

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -176
app.py DELETED
@@ -1,176 +0,0 @@
1
- """
2
- FastAPI Server for Text Correction
3
- Deploy this to run your text correction model as an API
4
- """
5
-
6
- from fastapi import FastAPI, HTTPException
7
- from fastapi.middleware.cors import CORSMiddleware
8
- from pydantic import BaseModel
9
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
10
- import torch
11
- import os
12
- from typing import Optional
13
-
14
- # Initialize FastAPI app
15
# FastAPI application instance exposing the text-correction endpoints.
app = FastAPI(
    title="Text Correction API",
    description="API for correcting OCR text using trained model",
    version="1.0.0"
)

# Allow cross-origin requests so the iOS client can call this API directly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify your iOS app's domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
29
-
30
- # Global variables for model
31
# Module-level model state; populated once by the startup hook.
model = None      # seq2seq model instance after loading, else None
tokenizer = None  # tokenizer paired with the model, else None
device = None     # "cuda" or "cpu" once chosen at startup, else None
34
-
35
- # Pydantic models for request/response
36
class TextRequest(BaseModel):
    """Request body for /correct: the raw OCR text to be corrected."""
    text: str
38
-
39
class TextResponse(BaseModel):
    """Response body for /correct."""
    corrected_text: str    # decoded model output, special tokens stripped
    processing_time: float  # seconds spent correcting, rounded to 2 decimals
42
-
43
class HealthResponse(BaseModel):
    """Response body for /health."""
    status: str        # "healthy" when the model is loaded, else "unhealthy"
    model_loaded: bool  # whether the global model is available
    device: str        # device in use, or "unknown" before startup completes
47
-
48
- # Load model at startup
49
@app.on_event("startup")
async def load_model():
    """Load the seq2seq model and tokenizer into module globals at startup.

    Reads the model directory from the MODEL_PATH environment variable
    (default "./gpu_base_model2"), picks CUDA when available, and puts the
    model in eval mode. On failure it only logs — the API stays up but
    /correct will answer 503 until a restart succeeds.
    """
    global model, tokenizer, device

    print("🚀 Starting Text Correction API...")

    # Prefer the GPU when one is visible to torch.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"📱 Using device: {device}")

    try:
        path = os.environ.get("MODEL_PATH", "./gpu_base_model2")
        print(f"📦 Loading model from: {path}")

        model = AutoModelForSeq2SeqLM.from_pretrained(path)
        tokenizer = AutoTokenizer.from_pretrained(path)

        # Move to the chosen device and disable dropout for inference.
        model.to(device)
        model.eval()

        print("✅ Model loaded successfully!")
        print(f" - Model type: {type(model).__name__}")
        print(f" - Vocabulary size: {tokenizer.vocab_size}")
        print(f" - Device: {device}")

    except Exception as exc:
        # Deliberate best-effort: keep serving so /health can report the state.
        print(f"❌ Error loading model: {exc}")
        print("⚠️ API will not work until model is loaded")
79
-
80
- # Health check endpoint
81
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Check if the API and model are ready"""
    loaded = model is not None
    return HealthResponse(
        status="healthy" if loaded else "unhealthy",
        model_loaded=loaded,
        device=device or "unknown",
    )
89
-
90
- # Text correction endpoint
91
@app.post("/correct", response_model=TextResponse)
async def correct_text(request: TextRequest):
    """
    Correct text using the trained model

    Args:
        request: TextRequest containing the text to correct

    Returns:
        TextResponse with corrected text and processing time

    Raises:
        HTTPException: 503 if the model is not loaded, 400 on empty input,
            500 if tokenization/generation fails.
    """
    import time

    # Fail fast while the startup hook has not (or could not) load the model.
    if model is None or tokenizer is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Please wait for the model to initialize."
        )

    if not request.text or not request.text.strip():
        raise HTTPException(
            status_code=400,
            detail="Text cannot be empty"
        )

    # Use the monotonic perf_counter for elapsed time: time.time() is
    # wall-clock and can jump (NTP/DST adjustments), yielding wrong or
    # even negative durations.
    start_time = time.perf_counter()

    try:
        # Tokenize input text (truncated to the model's 512-token window).
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=True
        ).to(device)

        # Generate corrected text with beam search.
        # NOTE(review): this synchronous generate() call blocks the event
        # loop for the whole request — consider run_in_executor if the
        # server must stay responsive under load; confirm before changing.
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=512,
                num_beams=5,
                early_stopping=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        # Decode the best beam, dropping pad/eos/special tokens.
        corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        processing_time = time.perf_counter() - start_time

        print(f"✅ Text corrected in {processing_time:.2f}s")
        print(f" Input: {request.text[:50]}...")
        print(f" Output: {corrected_text[:50]}...")

        return TextResponse(
            corrected_text=corrected_text,
            processing_time=round(processing_time, 2)
        )

    except Exception as e:
        print(f"❌ Error during correction: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Text correction failed: {str(e)}"
        )
160
-
161
- # Root endpoint
162
@app.get("/")
async def root():
    """Describe the API and list its available endpoints."""
    endpoints = {
        "health": "/health",
        "correct": "/correct (POST)",
    }
    return {
        "message": "Text Correction API",
        "version": "1.0.0",
        "endpoints": endpoints,
    }
172
-
173
# Allow running directly (`python app.py`): serve on all interfaces, port 8000.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
176
-