Vdv26's picture
Upload 3 files
9497244 verified
Raw
History Blame Contribute Delete
1.65 kB
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# ASGI application object; the run command at the bottom of this file
# refers to it as ``main:app``.
app = FastAPI()

# Enable CORS so the React frontend can communicate with this API
# NOTE(review): a wildcard origin together with allow_credentials=True is
# very permissive -- restrict allow_origins to the real frontend URL before
# deploying to production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], # In production, change this to your frontend URL
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

device = torch.device("cpu") # Use CPU for standard web hosting unless paying for GPU servers

# The processor and model are loaded once, at import time, and then reused
# by every request; startup therefore waits until loading has finished.
print("Loading model into server memory...")
# Point this to your fine-tuned local folder, or the base model if testing
model_path = "Vdv26/trocr-captcha-finetuned"
processor = TrOCRProcessor.from_pretrained(model_path)
model = VisionEncoderDecoderModel.from_pretrained(model_path).to(device)
@app.post("/api/predict")
async def predict_captcha(file: UploadFile = File(...)):
    """Read the text in an uploaded CAPTCHA image with the TrOCR model.

    Args:
        file: The image, uploaded as multipart form data.

    Returns:
        A dict with the uploaded ``filename`` and the decoded ``prediction``
        with every space character removed.

    Raises:
        HTTPException: Status 400 when the upload is empty or cannot be
            decoded as an image.
    """
    # 1. Read the uploaded bytes and decode them. The payload comes from the
    #    client, so a decoding failure is reported as a 400 (bad request)
    #    instead of escaping as an unhandled 500.
    contents = await file.read()
    try:
        with Image.open(io.BytesIO(contents)) as uploaded:
            # convert() returns a fully loaded copy, so the source image can
            # be closed as soon as the conversion is done.
            image = uploaded.convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # OSError also covers PIL.UnidentifiedImageError (empty or non-image
        # data) and truncated files.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    # 2. Run inference
    # NOTE(review): generate() is a synchronous call made directly inside
    # this coroutine, so the event loop is busy for its whole duration;
    # consider moving it to a worker thread if concurrent requests matter.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():  # gradients are not needed for prediction
        generated_ids = model.generate(pixel_values, max_new_tokens=10)
    prediction = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # 3. Return the JSON response to the frontend
    return {"filename": file.filename, "prediction": prediction.replace(" ", "")}
# Run locally using: uvicorn main:app --reload