Spaces:
Sleeping
Sleeping
import os
import random
from contextlib import asynccontextmanager
from pathlib import Path

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
# Pre-trained model loader
from huggingface_hub import from_pretrained_keras
from pydantic import BaseModel
from tensorflow import keras
from tensorflow.keras import layers
# --- Pydantic Models for Request Body ---
class CaptchaRequest(BaseModel):
    """Request body for the captcha-solving endpoint."""

    # Name of an image file looked up under IMAGE_DIR by solve_captcha().
    filename: str
# --- Global Variables ---
prediction_model = None  # inference-only Keras model; populated by lifespan()
num_to_char = None  # StringLookup layer mapping class indices -> characters; populated by lifespan()
max_length = 5  # maximum captcha text length (from the original Gradio script)

# --- Configuration for the pre-trained "keras-io/ocr-for-captcha" model ---
IMG_WIDTH = 200  # input width in pixels expected by the model
IMG_HEIGHT = 50  # input height in pixels expected by the model
# --- App Lifespan Management (Model Loading) ---
@asynccontextmanager  # required: FastAPI's lifespan= expects an async context manager
async def lifespan(app: FastAPI):
    """Load the pre-trained OCR model and vocabulary at application startup.

    Everything before ``yield`` runs once before the server starts handling
    requests; everything after runs at shutdown. On any load failure the
    globals stay ``None`` so the /solve endpoint can answer 503 instead of
    the whole server crashing.
    """
    global prediction_model, num_to_char
    try:
        print("INFO: Loading pre-trained Keras model and vocab...")
        # 1. Load the base model from Hugging Face Hub.
        base_model = from_pretrained_keras("keras-io/ocr-for-captcha", compile=False)
        # 2. Build an inference-only model: strip the CTC-loss head, keeping
        #    just image input -> "dense2" per-timestep class scores.
        prediction_model = keras.models.Model(
            base_model.get_layer(name="image").input,
            base_model.get_layer(name="dense2").output,
        )
        # 3. Load the character vocabulary (one entry per line).
        with open("vocab.txt", "r") as f:
            vocab = f.read().splitlines()
        # 4. Inverse lookup layer: class index -> character, used when decoding.
        num_to_char = layers.StringLookup(vocabulary=vocab, mask_token=None, invert=True)
        print("INFO: Model and vocab loaded successfully.")
    except Exception as e:
        # Best-effort startup: log and leave the model unset (endpoints check for None).
        print(f"ERROR: Failed to load pre-trained model or vocab: {e}")
        prediction_model = None
    yield
    print("INFO: Application shutting down.")
# Initialize the FastAPI app with the lifespan manager
app = FastAPI(lifespan=lifespan)

# --- CORS Middleware ---
# NOTE(review): allow_origins=["*"] accepts requests from any origin — fine
# for a demo, should be restricted for production use.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])

# Serve captcha images and other assets from ./static
app.mount("/static", StaticFiles(directory="static"), name="static")

# --- Constants ---
IMAGE_DIR = Path("static/images")  # directory holding the captcha images
# --- Helper Functions (from the original Gradio script) ---
def decode_batch_predictions(pred):
    """Greedy CTC-decode a batch of model outputs into text strings.

    ``pred`` is a (batch, timesteps, classes) score tensor; returns one
    decoded string per batch element, truncated to ``max_length`` characters.
    """
    batch_size, timesteps = pred.shape[0], pred.shape[1]
    # Every sequence in the batch spans the full number of timesteps.
    seq_lengths = np.ones(batch_size) * timesteps
    decoded = keras.backend.ctc_decode(pred, input_length=seq_lengths, greedy=True)
    # Keep only the best path, clipped to the expected captcha length.
    best_paths = decoded[0][0][:, :max_length]
    # Map class indices back to characters and join them into a string.
    return [
        tf.strings.reduce_join(num_to_char(path)).numpy().decode("utf-8")
        for path in best_paths
    ]
# --- API Endpoints ---
# NOTE(review): the route decorators were missing from the scraped source,
# leaving these handlers unregistered; paths reconstructed — confirm against
# the original deployment.
@app.get("/")
async def read_root():
    """Landing / health-check endpoint."""
    return {"message": "Welcome to the Captcha Solver API!"}
# NOTE(review): route decorator was missing from the scraped source;
# "/captcha" is a reconstruction — confirm the original path.
@app.get("/captcha")
async def get_captcha():
    """Return the filename of a randomly chosen captcha image.

    Raises:
        HTTPException 500: the image directory does not exist.
        HTTPException 404: the directory contains no image files.
    """
    if not IMAGE_DIR.is_dir():
        raise HTTPException(status_code=500, detail="Image directory not found.")
    image_files = [f for f in os.listdir(IMAGE_DIR) if f.endswith(('.png', '.jpg', '.jpeg'))]
    if not image_files:
        raise HTTPException(status_code=404, detail="No captcha images found.")
    return {"filename": random.choice(image_files)}
# NOTE(review): route decorator was missing from the scraped source;
# "/solve" is a reconstruction — confirm the original path.
@app.post("/solve")
async def solve_captcha(request: CaptchaRequest):
    """Run the OCR model on the named captcha image and return its text.

    Raises:
        HTTPException 503: model/vocab failed to load at startup.
        HTTPException 400: filename contains path components (traversal attempt).
        HTTPException 404: no such image file.
        HTTPException 500: any failure during preprocessing or inference.
    """
    if prediction_model is None or num_to_char is None:
        raise HTTPException(status_code=503, detail="Model or vocab is not loaded.")
    # Security: the filename comes from the client — reject anything with
    # directory components so it cannot escape IMAGE_DIR.
    if Path(request.filename).name != request.filename:
        raise HTTPException(status_code=400, detail="Invalid filename.")
    image_path = IMAGE_DIR / request.filename
    if not image_path.is_file():
        raise HTTPException(status_code=404, detail=f"File '{request.filename}' not found.")
    try:
        # Preprocessing mirrors the model's training pipeline:
        # read -> grayscale PNG -> float32 [0, 1] -> resize -> transpose -> batch.
        img = tf.io.read_file(str(image_path))  # tf.io needs a string path
        img = tf.io.decode_png(img, channels=1)  # single grayscale channel
        img = tf.image.convert_image_dtype(img, tf.float32)
        img = tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])
        # Model consumes (width, height, channels): the time axis comes first.
        img = tf.transpose(img, perm=[1, 0, 2])
        img = tf.expand_dims(img, axis=0)  # batch of one
        preds = prediction_model.predict(img)
        pred_text = decode_batch_predictions(preds)
        # Batch of one -> single prediction.
        return {"prediction": pred_text[0]}
    except Exception as e:
        print(f"Error during prediction: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred during model inference: {e}") from e