File size: 4,888 Bytes
7017186
 
 
6bc7bcc
 
caa7998
 
7017186
1eaa426
6bc7bcc
 
2ac0e47
 
1eaa426
caa7998
 
 
6bc7bcc
 
 
 
 
 
caa7998
 
6bc7bcc
caa7998
6bc7bcc
 
 
 
 
 
caa7998
6bc7bcc
caa7998
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bc7bcc
caa7998
6bc7bcc
 
 
 
 
 
 
1eaa426
 
caa7998
1eaa426
2ac0e47
 
7017186
 
 
caa7998
6bc7bcc
caa7998
 
6bc7bcc
caa7998
 
 
 
 
 
 
6bc7bcc
7017186
1eaa426
 
 
7017186
 
 
 
 
6bc7bcc
 
7017186
 
6bc7bcc
 
 
 
caa7998
 
6bc7bcc
 
 
 
 
 
caa7998
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bc7bcc
caa7998
 
7017186
6bc7bcc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import random
from pathlib import Path
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from contextlib import asynccontextmanager
from fastapi.staticfiles import StaticFiles


# New import for the pre-trained model
from huggingface_hub import from_pretrained_keras

# --- Pydantic Models for Request Body ---
class CaptchaRequest(BaseModel):
    """Request body for POST /solve_captcha.

    filename is the captcha image's filename, interpreted relative to
    IMAGE_DIR (the value returned by GET /get_captcha).
    """
    # name of an image file inside IMAGE_DIR
    filename: str

# --- Global Variables ---
prediction_model = None  # inference-only keras.Model; populated by lifespan(), None until loaded
num_to_char = None  # StringLookup layer mapping class indices back to characters; set in lifespan()
max_length = 5 # From your Gradio script  # max captcha length used to truncate CTC decoder output

# --- Configuration for the pre-trained "keras-io/ocr-for-captcha" model ---
# Input geometry expected by the pre-trained model (resize target in /solve_captcha).
IMG_WIDTH = 200
IMG_HEIGHT = 50

# --- App Lifespan Management (Model Loading) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the pre-trained OCR model and character vocabulary at startup.

    Best-effort: on any failure the app still starts, but BOTH globals are
    reset to None so /solve_captcha reliably answers 503 instead of running
    with a half-initialized pipeline.
    """
    global prediction_model, num_to_char
    try:
        print("INFO:     Loading pre-trained Keras model and vocab...")

        # 1. Load the base (training) model from the Hugging Face Hub.
        base_model = from_pretrained_keras("keras-io/ocr-for-captcha", compile=False)

        # 2. Build an inference-only model: image input -> "dense2" logits.
        #    This strips the CTC-loss training head from the base model.
        prediction_model = keras.models.Model(
            base_model.get_layer(name="image").input,
            base_model.get_layer(name="dense2").output,
        )

        # 3. Load the character vocabulary (one character per line).
        #    Explicit encoding: do not depend on the platform default.
        with open("vocab.txt", "r", encoding="utf-8") as f:
            vocab = f.read().splitlines()

        # 4. Inverse lookup layer: integer class index -> character string.
        num_to_char = layers.StringLookup(vocabulary=vocab, mask_token=None, invert=True)

        print("INFO:     Model and vocab loaded successfully.")

    except Exception as e:
        print(f"ERROR:    Failed to load pre-trained model or vocab: {e}")
        # BUG FIX: previously only prediction_model was reset. If the failure
        # happened after a partial load (e.g. vocab.txt missing on a reload),
        # num_to_char could be left stale while prediction_model was None.
        prediction_model = None
        num_to_char = None
    yield
    print("INFO:     Application shutting down.")


# Initialize the FastAPI app with the lifespan manager
app = FastAPI(lifespan=lifespan)

# --- CORS Middleware ---
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])

app.mount("/static", StaticFiles(directory="static"), name="static")

# --- Constants ---
IMAGE_DIR = Path("static/images")

# --- Helper Functions (from your Gradio script) ---

def decode_batch_predictions(pred):
    """Greedy CTC-decode a batch of model outputs into text strings.

    pred is the (batch, timesteps, classes) logit tensor from the model;
    returns one decoded string per batch element, truncated to max_length.
    """
    # Every sequence in the batch spans the full timestep axis.
    seq_lengths = np.ones(pred.shape[0]) * pred.shape[1]
    decoded = keras.backend.ctc_decode(pred, input_length=seq_lengths, greedy=True)
    # ctc_decode returns ([sequences], log_probs); keep the best path,
    # capped at the known captcha length.
    best_paths = decoded[0][0][:, :max_length]
    # Map each index sequence back to characters and join into a string.
    return [
        tf.strings.reduce_join(num_to_char(seq)).numpy().decode("utf-8")
        for seq in best_paths
    ]

# --- API Endpoints ---
@app.get("/")
async def read_root():
    """Landing endpoint: confirms the API is up."""
    greeting = {"message": "Welcome to the Captcha Solver API!"}
    return greeting

@app.get("/get_captcha")
async def get_captcha():
    """Return the filename of a randomly chosen captcha image.

    Returns:
        {"filename": str} — a name inside IMAGE_DIR, suitable for
        POST /solve_captcha.

    Raises:
        HTTPException 500 if the image directory is missing.
        HTTPException 404 if it contains no images.
    """
    if not IMAGE_DIR.is_dir():
        raise HTTPException(status_code=500, detail="Image directory not found.")
    # pathlib iteration instead of os.listdir; suffix match is now
    # case-insensitive so .PNG/.JPG files are also served.
    image_files = [
        p.name
        for p in IMAGE_DIR.iterdir()
        if p.suffix.lower() in {".png", ".jpg", ".jpeg"}
    ]
    if not image_files:
        raise HTTPException(status_code=404, detail="No captcha images found.")
    return {"filename": random.choice(image_files)}

@app.post("/solve_captcha")
async def solve_captcha(request: CaptchaRequest):
    """Run the OCR model on the requested captcha image and return its text.

    Returns:
        {"prediction": str} — the decoded captcha text.

    Raises:
        HTTPException 503 if the model/vocab failed to load at startup.
        HTTPException 400 if the filename contains path components.
        HTTPException 404 if the file does not exist in IMAGE_DIR.
        HTTPException 500 on any inference error.
    """
    if prediction_model is None or num_to_char is None:
        raise HTTPException(status_code=503, detail="Model or vocab is not loaded.")

    # SECURITY FIX: filename is untrusted input. Reject anything with path
    # components so "../../etc/passwd" cannot escape IMAGE_DIR.
    if Path(request.filename).name != request.filename:
        raise HTTPException(status_code=400, detail="Invalid filename.")

    image_path = IMAGE_DIR / request.filename
    if not image_path.is_file():
        raise HTTPException(status_code=404, detail=f"File '{request.filename}' not found.")

    try:
        # Preprocessing mirrors the pipeline the model was trained with.

        # 1. Read image bytes (tf.io wants a string path, not a Path).
        img = tf.io.read_file(str(image_path))
        # 2. Decode and convert to grayscale (single channel).
        img = tf.io.decode_png(img, channels=1)
        # 3. Convert to float32 in the [0, 1] range.
        img = tf.image.convert_image_dtype(img, tf.float32)
        # 4. Resize to the model's expected input geometry.
        img = tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])
        # 5. Transpose so the width axis becomes the "time" axis for CTC.
        img = tf.transpose(img, perm=[1, 0, 2])
        # 6. Add a batch dimension (batch of one).
        img = tf.expand_dims(img, axis=0)

        # 7. Forward pass.
        preds = prediction_model.predict(img)

        # 8. CTC-decode the logits into text.
        pred_text = decode_batch_predictions(preds)

        # Single-image batch: return the first (and only) prediction.
        return {"prediction": pred_text[0]}

    except Exception as e:
        # Boundary handler: log and surface a 500 rather than crash the worker.
        print(f"Error during prediction: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred during model inference: {e}")