Spaces:
Sleeping
Sleeping
| import torch | |
| import onnx | |
| import onnxruntime as rt | |
| from torchvision import transforms as T | |
| from PIL import Image | |
| from tokenizer_base import Tokenizer | |
| import pathlib | |
| import os | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import base64 | |
| from io import BytesIO | |
| from huggingface_hub import hf_hub_download | |
| import shutil | |
# Resolve the model location relative to this source file.
cwd = pathlib.Path(__file__).parent.resolve()
model_dir = os.path.join(cwd, "secret_models")
model_file = os.path.join(model_dir, "captcha.onnx")

# Create the target directory if it does not exist yet.
os.makedirs(model_dir, exist_ok=True)

# Download the model from Hugging Face unless it is already present on disk.
if not os.path.exists(model_file):
    print("Downloading model from Hugging Face...")
    try:
        downloaded_file = hf_hub_download(
            repo_id="docparser/captcha",
            filename="captcha.onnx",
            repo_type="model",
            token=True,  # NOTE(review): uses the locally stored HF token — confirm the repo is gated
        )
        shutil.copy(downloaded_file, model_file)
        print(f"Model downloaded to {model_file}")
    except Exception as e:
        print(f"Error downloading model: {e}")
        # If the file is already in place (e.g. shipped manually), continue anyway.
        if not os.path.exists(model_file):
            raise
# Model input resolution (height, width) and the character set the OCR head emits.
img_size = (32, 128)
charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"

# Tokenizer turning model logits into text, and the FastAPI application object.
tokenizer_base = Tokenizer(charset)
app = FastAPI(title="Text Captcha Reader API")
def get_transform(img_size):
    """Build the image preprocessing pipeline: resize, tensor conversion, normalize.

    `img_size` is a (height, width) pair passed to torchvision's Resize.
    """
    return T.Compose([
        T.Resize(img_size, T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(0.5, 0.5),
    ])
def to_numpy(tensor):
    """Return *tensor* as a NumPy array on the CPU, detaching from autograd first if needed."""
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
def initialize_model(model_file):
    """Validate the ONNX model on disk and create the runtime objects.

    Returns a (transform, ort_session) pair: the preprocessing pipeline for the
    module-level `img_size`, and an onnxruntime inference session for the model.
    """
    preprocess = get_transform(img_size)
    # Sanity-check the model graph before handing the file to onnxruntime.
    onnx.checker.check_model(onnx.load(model_file))
    session = rt.InferenceSession(model_file)
    return preprocess, session
def get_text(img_org):
    """Run OCR on a PIL image and return the decoded captcha string.

    Uses the module-level `transform`, `ort_session` and `tokenizer_base`.
    """
    # Preprocess. Model expects a batch of RGB images with shape (B, C, H, W).
    x = transform(img_org.convert('RGB')).unsqueeze(0)
    # Compute the ONNX Runtime output prediction.
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
    logits = ort_session.run(None, ort_inputs)[0]
    probs = torch.tensor(logits).softmax(-1)
    # decode() returns (strings, confidences); we only need the first string.
    # (Fixed: the original shadowed `probs` with the confidences and left a
    # debug print() in the request path.)
    preds, _ = tokenizer_base.decode(probs)
    return preds[0]
# Build the preprocessing transform and the ONNX session once, at import time.
transform, ort_session = initialize_model(model_file=model_file)


class ImageRequest(BaseModel):
    """Request payload for the captcha endpoint."""
    image: str  # base64-encoded image bytes
# NOTE(review): no route decorator was visible in the original extract (possibly
# lost during extraction) — without one this handler is never registered with
# the app. Registered here at /predict; confirm the path against API clients.
@app.post("/predict")
async def predict_captcha(request: ImageRequest):
    """Decode a base64-encoded image and return the recognized captcha text.

    Always answers with a {"success": bool, ...} envelope: errors are reported
    in-band via "error" rather than through HTTPException, preserving the
    original best-effort contract.
    """
    try:
        # Decode the base64 payload into a PIL image.
        image_data = base64.b64decode(request.image)
        img = Image.open(BytesIO(image_data))
        # Run inference.
        text = get_text(img)
        return {
            "success": True,
            "text": text
        }
    except Exception as e:  # boundary handler: surface any failure in-band
        return {
            "success": False,
            "error": str(e)
        }
async def health_check():
    """Liveness probe; always reports ok."""
    # NOTE(review): no route decorator visible in this extract — presumably
    # registered via app.get(...); confirm against the deployed app.
    return {"status": "ok"}
def read_root():
    """Root endpoint banner message."""
    # NOTE(review): route decorator not visible in this extract — presumably app.get("/").
    return {"message": "API is running!"}