"""FastAPI service that serves random captcha images and solves them with the
pre-trained Keras OCR model "keras-io/ocr-for-captcha" from the Hugging Face Hub."""

import os
import random
from contextlib import asynccontextmanager
from pathlib import Path

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

# New import for the pre-trained model
from huggingface_hub import from_pretrained_keras


# --- Pydantic Models for Request Body ---
class CaptchaRequest(BaseModel):
    """Body of POST /solve_captcha: bare filename of an image in IMAGE_DIR."""

    filename: str


# --- Global Variables (populated by the lifespan handler at startup) ---
prediction_model = None  # inference-only Keras model: image input -> dense2 output
num_to_char = None  # StringLookup layer mapping label ids back to characters
max_length = 5  # maximum captcha length the model was trained on

# --- Configuration for the pre-trained "keras-io/ocr-for-captcha" model ---
IMG_WIDTH = 200
IMG_HEIGHT = 50


# --- App Lifespan Management (Model Loading) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the OCR model and vocabulary once at startup.

    On failure the globals stay None, so endpoints answer 503 instead of the
    whole server crashing at import time.
    """
    global prediction_model, num_to_char
    try:
        print("INFO: Loading pre-trained Keras model and vocab...")
        # 1. Load the base model from Hugging Face Hub
        base_model = from_pretrained_keras("keras-io/ocr-for-captcha", compile=False)
        # 2. Strip the CTC-loss training head: keep image input -> dense2 output
        prediction_model = keras.models.Model(
            base_model.get_layer(name="image").input,
            base_model.get_layer(name="dense2").output,
        )
        # 3. Load the vocabulary (one character per line)
        with open("vocab.txt", "r", encoding="utf-8") as f:
            vocab = f.read().splitlines()
        # 4. Inverse lookup: integer class ids -> characters
        num_to_char = layers.StringLookup(
            vocabulary=vocab, mask_token=None, invert=True
        )
        print("INFO: Model and vocab loaded successfully.")
    except Exception as e:
        print(f"ERROR: Failed to load pre-trained model or vocab: {e}")
        prediction_model = None
    yield
    print("INFO: Application shutting down.")


# Initialize the FastAPI app with the lifespan manager
app = FastAPI(lifespan=lifespan)

# --- CORS Middleware ---
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.mount("/static", StaticFiles(directory="static"), name="static")

# --- Constants ---
IMAGE_DIR = Path("static/images")


# --- Helper Functions ---
def decode_batch_predictions(pred):
    """Greedy CTC-decode a batch of model outputs into text strings.

    pred: array of shape (batch, time_steps, num_classes) from the model.
    Returns a list of decoded strings, one per batch element, each truncated
    to at most ``max_length`` characters.
    """
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Greedy CTC decode, then keep at most max_length label ids per sample
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
        :, :max_length
    ]
    output_text = []
    for res in results:
        res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
        output_text.append(res)
    return output_text


# --- API Endpoints ---
@app.get("/")
async def read_root():
    """Health/welcome endpoint."""
    return {"message": "Welcome to the Captcha Solver API!"}


@app.get("/get_captcha")
async def get_captcha():
    """Return the filename of a randomly chosen captcha image from IMAGE_DIR."""
    if not IMAGE_DIR.is_dir():
        raise HTTPException(status_code=500, detail="Image directory not found.")
    image_files = [
        f for f in os.listdir(IMAGE_DIR) if f.endswith((".png", ".jpg", ".jpeg"))
    ]
    if not image_files:
        raise HTTPException(status_code=404, detail="No captcha images found.")
    return {"filename": random.choice(image_files)}


@app.post("/solve_captcha")
async def solve_captcha(request: CaptchaRequest):
    """Run the OCR model on the requested image and return the decoded text.

    Raises 503 if the model is not loaded, 400 for unsafe filenames, 404 if
    the image does not exist, and 500 on inference errors.
    """
    if prediction_model is None or num_to_char is None:
        raise HTTPException(status_code=503, detail="Model or vocab is not loaded.")
    # SECURITY: the filename is client-controlled. Reject anything that is not
    # a bare name (e.g. "../vocab.txt", absolute paths, subdirectories) so the
    # lookup cannot escape IMAGE_DIR.
    if Path(request.filename).name != request.filename:
        raise HTTPException(status_code=400, detail="Invalid filename.")
    image_path = IMAGE_DIR / request.filename
    if not image_path.is_file():
        raise HTTPException(
            status_code=404, detail=f"File '{request.filename}' not found."
        )
    try:
        # Preprocessing mirrors the model's training pipeline:
        # 1. Read image bytes
        img = tf.io.read_file(str(image_path))  # Path -> str for tf.io
        # 2. Decode and convert to grayscale (1 channel)
        img = tf.io.decode_png(img, channels=1)
        # 3. Convert to float32 in [0, 1] range
        img = tf.image.convert_image_dtype(img, tf.float32)
        # 4. Resize to the model's expected input size
        img = tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])
        # 5. Transpose so the width (time) axis comes first, as the model expects
        img = tf.transpose(img, perm=[1, 0, 2])
        # 6. Add a batch dimension
        img = tf.expand_dims(img, axis=0)
        # 7. Run inference
        preds = prediction_model.predict(img)
        # 8. CTC-decode the predictions to text
        pred_text = decode_batch_predictions(preds)
        # Return the first (and only) prediction in the batch
        return {"prediction": pred_text[0]}
    except Exception as e:
        print(f"Error during prediction: {e}")
        raise HTTPException(
            status_code=500, detail=f"An error occurred during model inference: {e}"
        )