"""FastAPI captcha-solving service (csaf-captcha, app/main.py)."""
import os
import random
from pathlib import Path
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from contextlib import asynccontextmanager
from fastapi.staticfiles import StaticFiles
# Loads the pre-trained model from the Hugging Face Hub
from huggingface_hub import from_pretrained_keras
# --- Pydantic Models for Request Body ---
class CaptchaRequest(BaseModel):
filename: str
# --- Global Variables ---
prediction_model = None
num_to_char = None
max_length = 5  # maximum number of characters in a captcha label
# --- Configuration for the pre-trained "keras-io/ocr-for-captcha" model ---
IMG_WIDTH = 200
IMG_HEIGHT = 50
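# The "keras-io/ocr-for-captcha" example model was trained on 200x50 grayscale
# captcha images, so inputs are resized to these dimensions before inference.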
# --- App Lifespan Management (Model Loading) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
global prediction_model, num_to_char
try:
print("INFO: Loading pre-trained Keras model and vocab...")
# 1. Load the base model from Hugging Face Hub
base_model = from_pretrained_keras("keras-io/ocr-for-captcha", compile=False)
        # 2. Build an inference-only model from the image input to the softmax output
prediction_model = keras.models.Model(
base_model.get_layer(name="image").input, base_model.get_layer(name="dense2").output
)
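        # The Hub checkpoint is the full training graph, which ends in a CTC-loss
        # layer; slicing from the "image" input to the "dense2" softmax output
        # yields a model suitable for pure inference.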
# 3. Load the vocabulary from the file
with open("vocab.txt", "r") as f:
vocab = f.read().splitlines()
        # 4. Create the lookup layer that maps predicted indices back to characters
num_to_char = layers.StringLookup(vocabulary=vocab, mask_token=None, invert=True)
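        # With invert=True, the StringLookup maps integer class indices back to
        # vocabulary characters (index 0 is reserved for the OOV token, so
        # e.g. num_to_char([1]) yields the first entry of vocab.txt).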
print("INFO: Model and vocab loaded successfully.")
    except Exception as e:
        print(f"ERROR: Failed to load pre-trained model or vocab: {e}")
        prediction_model = None
        num_to_char = None
yield
print("INFO: Application shutting down.")
# Initialize the FastAPI app with the lifespan manager
app = FastAPI(lifespan=lifespan)
# --- CORS Middleware ---
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
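# NOTE: allow_origins=["*"] accepts requests from any origin; restrict this to
# known front-end origins before deploying publicly.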
app.mount("/static", StaticFiles(directory="static"), name="static")
# --- Constants ---
IMAGE_DIR = Path("static/images")
# --- Helper Functions ---
def decode_batch_predictions(pred):
    # Greedy CTC decoding: every batch element uses the full time dimension,
    # so input_length is the number of prediction timesteps for each sample.
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # ctc_decode returns (decoded_sequences, log_probabilities); keep the decoded
    # index sequences and truncate each to the maximum captcha length.
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][:, :max_length]
output_text = []
for res in results:
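        # Map each decoded index sequence back to characters and join them into
        # a single UTF-8 string.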
res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
output_text.append(res)
return output_text
# --- API Endpoints ---
@app.get("/")
async def read_root():
return {"message": "Welcome to the Captcha Solver API!"}
@app.get("/get_captcha")
async def get_captcha():
if not IMAGE_DIR.is_dir():
raise HTTPException(status_code=500, detail="Image directory not found.")
    image_files = [f for f in os.listdir(IMAGE_DIR) if f.lower().endswith((".png", ".jpg", ".jpeg"))]
if not image_files:
raise HTTPException(status_code=404, detail="No captcha images found.")
return {"filename": random.choice(image_files)}
@app.post("/solve_captcha")
async def solve_captcha(request: CaptchaRequest):
if prediction_model is None or num_to_char is None:
raise HTTPException(status_code=503, detail="Model or vocab is not loaded.")
    # Resolve the path and confine it to IMAGE_DIR to block traversal (e.g. "../")
    image_path = (IMAGE_DIR / request.filename).resolve()
    if IMAGE_DIR.resolve() not in image_path.parents:
        raise HTTPException(status_code=400, detail="Invalid filename.")
    if not image_path.is_file():
        raise HTTPException(status_code=404, detail=f"File '{request.filename}' not found.")
try:
        # 1. Read the raw image bytes
        img = tf.io.read_file(str(image_path))  # tf.io expects a string path, not a Path
        # 2. Decode and convert to grayscale; tf.io.decode_image handles both PNG
        #    and JPEG, unlike decode_png, which would fail on the .jpg/.jpeg
        #    files that /get_captcha can serve
        img = tf.io.decode_image(img, channels=1, expand_animations=False)
# 3. Convert to float32 in [0, 1] range
img = tf.image.convert_image_dtype(img, tf.float32)
        # 4. Resize to the model's expected input size
img = tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])
# 5. Transpose the image
img = tf.transpose(img, perm=[1, 0, 2])
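        #    (the model scans the captcha left to right, so image width serves
        #    as the CTC time dimension and must come first)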
# 6. Add a batch dimension
img = tf.expand_dims(img, axis=0)
# 7. Get predictions
preds = prediction_model.predict(img)
# 8. Decode the predictions
pred_text = decode_batch_predictions(preds)
# Return the first (and only) prediction
return {"prediction": pred_text[0]}
except Exception as e:
print(f"Error during prediction: {e}")
raise HTTPException(status_code=500, detail=f"An error occurred during model inference: {e}")
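# --- Example usage (a sketch; host, port, and filename below are illustrative) ---
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   # Fetch a random captcha filename:
#   curl http://localhost:8000/get_captcha
#   # -> {"filename": "2b827.png"}
#
#   # Ask the model to solve it:
#   curl -X POST http://localhost:8000/solve_captcha \
#        -H "Content-Type: application/json" \
#        -d '{"filename": "2b827.png"}'
#   # -> {"prediction": "2b827"}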