import io
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
# --- 1. SCRIPT SETUP ---
# Select the compute device once at import time; everything below (model
# placement, input tensors) uses this single DEVICE constant.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"--- Running on {DEVICE} ---")
# Hugging Face Hub ID of the checkpoint to load (model and processor share it).
MODEL_ID = "microsoft/Florence-2-large"
# NOTE(review): the "-ft" suffix denotes the *fine-tuned* Florence-2 variant,
# not a float16 build — the original comment claiming it was a float16 version
# was misleading. Swap the ID below if the fine-tuned checkpoint is preferred.
# MODEL_ID = "microsoft/Florence-2-large-ft"
# --- 2. LOAD MODEL AND PROCESSOR ---
# trust_remote_code=True is required for Florence-2 (its modeling code lives
# in the Hub repo, not in the transformers package).
# Use float16 only on CUDA: half-precision inference on CPU is poorly
# supported by many ops and can fail or run far slower than float32.
try:
    _dtype = torch.float16 if DEVICE.type == "cuda" else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, trust_remote_code=True, torch_dtype=_dtype
    ).to(DEVICE)
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    print("--- Model and processor loaded successfully ---")
except Exception as e:
    # Deliberate best-effort: keep the app importable/startable even if the
    # model fails to load. Requests then get a 503 from run_florence2_ocr's
    # guard instead of the server crashing at import time.
    print(f"--- Error loading model: {e} ---")
    model = None
    processor = None
# --- 3. FASTAPI APP INITIALIZATION ---
# OpenAPI metadata shown on the auto-generated /docs page.
_API_DESCRIPTION = (
    "An API for extracting text from images using Microsoft's Florence-2-large model. "
    "Handles both printed and handwritten text."
)
app = FastAPI(
    title="Florence-2 OCR API",
    description=_API_DESCRIPTION,
    version="1.0.0",
)
# --- 4. HELPER FUNCTION ---
def run_florence2_ocr(image: Image.Image) -> str:
    """Run Florence-2 OCR on a PIL image and return the extracted text.

    Args:
        image: Input image in any PIL mode; converted to RGB as the model
            requires.

    Returns:
        The text extracted from the image, or a fallback error string if the
        model output could not be parsed.

    Raises:
        HTTPException: 503 if the model/processor failed to load at startup.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model is not available. Please check server logs.")

    # Florence-2 task token selecting the plain-OCR head.
    task_prompt = "<OCR>"

    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(text=task_prompt, images=image, return_tensors="pt").to(DEVICE)
    # Match the vision input to whatever dtype the model was loaded in
    # (e.g. float16 on GPU). Cast only pixel_values — input_ids must stay
    # integer-typed for generate(), so a blanket cast of the whole batch
    # is avoided.
    if inputs["pixel_values"].dtype != model.dtype:
        inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=2048,  # generous limit for dense documents
        num_beams=3,
        do_sample=False,  # greedy/beam decoding for deterministic results
    )

    # Keep special tokens: post_process_generation relies on them to locate
    # the task span inside the raw decoded string.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed = processor.post_process_generation(
        generated_text, task=task_prompt, image_size=(image.width, image.height)
    )
    return parsed.get('<OCR>', "Error: Could not parse OCR output.")
# --- 5. API ENDPOINTS ---
@app.get("/", summary="Root Endpoint", description="Returns a welcome message.")
def read_root():
    """Landing endpoint that points clients at the interactive docs."""
    welcome = {"message": "Welcome to the Florence-2 OCR API. Go to /docs for usage."}
    return welcome
@app.post("/ocr", summary="Extract Text from Image", description="Upload an image file to extract text. Supports both computer and handwritten text.")
async def extract_text_from_image(file: UploadFile = File(..., description="Image file to process.")):
    """Perform OCR on an uploaded image.

    Returns:
        JSON object with the original filename and the extracted text.

    Raises:
        HTTPException: 400 if the upload is not a readable image; 503 if the
            model is unavailable (propagated from run_florence2_ocr); 500 for
            any other inference failure.
    """
    # Decode the upload into a PIL image; keep the try body minimal so only
    # read/open failures map to the 400 response.
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image file. Could not open image.") from exc

    try:
        extracted_text = run_florence2_ocr(image)
    except HTTPException:
        # Bug fix: previously the broad handler below swallowed deliberate
        # HTTP errors (e.g. the 503 "model not loaded" guard) and re-wrapped
        # them as a generic 500. Let them propagate unchanged.
        raise
    except Exception as e:
        print(f"Error during model inference: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred during processing: {str(e)}") from e
    return {"filename": file.filename, "extracted_text": extracted_text}