OmarAbualrob commited on
Commit
621c44f
·
verified ·
1 Parent(s): 6b29736

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -56
app.py CHANGED
@@ -1,107 +1,109 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from fastapi.responses import JSONResponse
3
- from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
4
  from PIL import Image
5
  import torch
6
  import io
7
  import logging
8
 
9
- # --- 1. Basic Setup ---
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
- app = FastAPI(title="Florence-2 OCR API", description="An API to extract text from images using the Florence-2-large model on a GPU.")
13
 
14
- # --- 2. Global Variables and Device Configuration ---
15
- device = "cuda" if torch.cuda.is_available() else "cpu"
16
- torch_dtype = torch.bfloat16
17
- model = None
18
- processor = None
19
 
20
- # --- 3. Model Loading Logic (at startup) ---
21
- @app.on_event("startup")
22
- async def startup_event():
23
- global model, processor
24
-
25
- if device == "cpu":
26
- logger.warning("CUDA not available, model will not be loaded. This API requires a GPU.")
27
- return
28
-
29
- try:
30
- logger.info(f"Using device: {device}")
31
- logger.info("Starting model loading process with 4-bit quantization...")
32
-
33
- model_id = "microsoft/Florence-2-large"
34
-
35
- quantization_config = BitsAndBytesConfig(
36
- load_in_4bit=True,
37
- bnb_4bit_compute_dtype=torch_dtype
38
- )
39
-
40
- model = AutoModelForCausalLM.from_pretrained(
41
- model_id,
42
- trust_remote_code=True,
43
- quantization_config=quantization_config,
44
- )
45
-
46
- processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
47
-
48
- logger.info("Model and processor loaded successfully.")
49
-
50
- except Exception as e:
51
- logger.error(f"FATAL: An error occurred during model loading: {e}", exc_info=True)
52
 
53
- # --- 4. Define the OCR Task Function ---
54
  def run_ocr(image: Image.Image) -> str:
 
 
 
55
  if model is None or processor is None:
56
- raise RuntimeError("Model is not available. Check startup logs for loading errors.")
57
 
 
58
  if image.mode != "RGB":
59
  image = image.convert("RGB")
60
 
 
61
  prompt = "<OCR>"
62
- inputs = processor(text=prompt, images=image, return_tensors="pt")
63
 
64
- input_ids = inputs["input_ids"].to(device)
65
- pixel_values = inputs["pixel_values"].to(device, dtype=torch_dtype)
66
 
 
 
67
  generated_ids = model.generate(
68
- input_ids=input_ids,
69
- pixel_values=pixel_values,
70
- max_new_tokens=4096,
71
- do_sample=False,
72
  num_beams=3
73
  )
74
 
 
75
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
 
 
 
76
  parsed_text = processor.post_process_generation(generated_text, task="<OCR>", image_size=(image.width, image.height))
77
 
78
  return parsed_text.get("<OCR>", "Error: Could not parse OCR output.")
79
 
80
 
81
- # --- 5. Create API Endpoints ---
82
  @app.post("/ocr", summary="Extract Text from Image")
83
  async def perform_ocr(file: UploadFile = File(..., description="Image file to perform OCR on.")):
 
 
 
 
84
  if model is None:
85
- raise HTTPException(status_code=503, detail="Model is not loaded or unavailable. Please check the server logs.")
86
 
 
87
  if not file.content_type.startswith("image/"):
88
- raise HTTPException(status_code=400, detail="Invalid file type. Please upload an image (e.g., PNG, JPG).")
89
 
90
  try:
 
91
  contents = await file.read()
92
  image = Image.open(io.BytesIO(contents))
93
 
94
- logger.info(f"Running OCR on uploaded file: {file.filename}")
 
95
  extracted_text = run_ocr(image)
96
  logger.info("OCR completed successfully.")
97
 
 
98
  return JSONResponse(
99
  content={"filename": file.filename, "text": extracted_text}
100
  )
 
101
  except Exception as e:
102
- logger.error(f"An error occurred during OCR processing for {file.filename}: {e}", exc_info=True)
103
- raise HTTPException(status_code=500, detail=f"An internal server error occurred: {str(e)}")
104
 
105
  @app.get("/", summary="Health Check")
106
  def read_root():
107
- return {"status": "ok", "model_loaded": model is not None, "device": device}
 
 
 
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from fastapi.responses import JSONResponse
3
+ from transformers import AutoProcessor, AutoModelForCausalLM
4
  from PIL import Image
5
  import torch
6
  import io
7
  import logging
8
 
9
+ # Set up logging
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
 
12
 
13
+ # --- 1. Initialize FastAPI App ---
14
+ app = FastAPI(title="Mixed-Content OCR API", description="An API to extract text from images containing both printed and handwritten text.")
 
 
 
15
 
16
+ # --- 2. Load the Model and Processor (at startup) ---
17
+ # This is a critical step. We load the model only once when the app starts.
18
+ # This prevents reloading the model on every API call, which would be very slow.
19
+ try:
20
+ logger.info("Loading model and processor...")
21
+ # Use the large model for better accuracy
22
+ model_id = "microsoft/Florence-2-large"
23
+ # NOTE: We need to trust remote code for Florence-2
24
+ model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
25
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
26
+ logger.info("Model and processor loaded successfully.")
27
+ except Exception as e:
28
+ logger.error(f"Error loading model: {e}")
29
+ # If the model fails to load, the API is not usable. We can't proceed.
30
+ model = None
31
+ processor = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ # --- 3. Define the OCR Task Function ---
34
  def run_ocr(image: Image.Image) -> str:
35
+ """
36
+ Performs OCR on a given PIL Image using the Florence-2 model.
37
+ """
38
  if model is None or processor is None:
39
+ raise RuntimeError("Model is not available. Check logs for loading errors.")
40
 
41
+ # Ensure image is in RGB format
42
  if image.mode != "RGB":
43
  image = image.convert("RGB")
44
 
45
+ # Define the task prompt
46
  prompt = "<OCR>"
 
47
 
48
+ # Preprocess the image and prompt
49
+ inputs = processor(text=prompt, images=image, return_tensors="pt")
50
 
51
+ # Generate text from the image
52
+ # Note: max_new_tokens can be adjusted based on expected text length
53
  generated_ids = model.generate(
54
+ input_ids=inputs["input_ids"],
55
+ pixel_values=inputs["pixel_values"],
56
+ max_new_tokens=4096, # Increased token limit for long documents
57
+ do_sample=False, # Use greedy decoding for deterministic output
58
  num_beams=3
59
  )
60
 
61
+ # Decode the generated IDs to a string
62
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
63
+
64
+ # Post-process the output to get the clean text
65
+ # The model's output for OCR is typically in the format: <OCR>extracted_text</s>
66
  parsed_text = processor.post_process_generation(generated_text, task="<OCR>", image_size=(image.width, image.height))
67
 
68
  return parsed_text.get("<OCR>", "Error: Could not parse OCR output.")
69
 
70
 
71
+ # --- 4. Create the API Endpoint ---
72
  @app.post("/ocr", summary="Extract Text from Image")
73
  async def perform_ocr(file: UploadFile = File(..., description="Image file to perform OCR on.")):
74
+ """
75
+ Takes an image file, extracts both printed and handwritten text,
76
+ and returns it as a JSON object.
77
+ """
78
  if model is None:
79
+ raise HTTPException(status_code=503, detail="Model is not loaded or unavailable.")
80
 
81
+ # Validate file type
82
  if not file.content_type.startswith("image/"):
83
+ raise HTTPException(status_code=400, detail="Invalid file type. Please upload an image.")
84
 
85
  try:
86
+ # Read the image content from the uploaded file
87
  contents = await file.read()
88
  image = Image.open(io.BytesIO(contents))
89
 
90
+ # Run the OCR task
91
+ logger.info("Running OCR on the uploaded image...")
92
  extracted_text = run_ocr(image)
93
  logger.info("OCR completed successfully.")
94
 
95
+ # Return the result
96
  return JSONResponse(
97
  content={"filename": file.filename, "text": extracted_text}
98
  )
99
+
100
  except Exception as e:
101
+ logger.error(f"An error occurred during OCR processing: {e}")
102
+ raise HTTPException(status_code=500, detail=f"An internal error occurred: {str(e)}")
103
 
104
  @app.get("/", summary="Health Check")
105
  def read_root():
106
+ """
107
+ A simple health check endpoint to confirm the API is running.
108
+ """
109
+ return {"status": "ok", "model_loaded": model is not None}