OmarAbualrob commited on
Commit
2a409bc
·
verified ·
1 Parent(s): a039eef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -20
app.py CHANGED
@@ -14,19 +14,31 @@ logger = logging.getLogger(__name__)
14
  app = FastAPI(title="Mixed-Content OCR API", description="An API to extract text from images containing both printed and handwritten text.")
15
 
16
  # --- 2. Load the Model and Processor (at startup) ---
17
- # This is a critical step. We load the model only once when the app starts.
18
- # This prevents reloading the model on every API call, which would be very slow.
 
 
 
 
 
 
19
  try:
20
  logger.info("Loading model and processor...")
21
- # Use the large model for better accuracy
22
  model_id = "microsoft/Florence-2-large"
23
- # NOTE: We need to trust remote code for Florence-2
24
- model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
 
 
 
 
 
 
25
  processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
 
26
  logger.info("Model and processor loaded successfully.")
27
  except Exception as e:
28
  logger.error(f"Error loading model: {e}")
29
- # If the model fails to load, the API is not usable. We can't proceed.
30
  model = None
31
  processor = None
32
 
@@ -38,31 +50,27 @@ def run_ocr(image: Image.Image) -> str:
38
  if model is None or processor is None:
39
  raise RuntimeError("Model is not available. Check logs for loading errors.")
40
 
41
- # Ensure image is in RGB format
42
  if image.mode != "RGB":
43
  image = image.convert("RGB")
44
 
45
- # Define the task prompt
46
  prompt = "<OCR>"
47
 
48
  # Preprocess the image and prompt
49
  inputs = processor(text=prompt, images=image, return_tensors="pt")
50
 
51
- # Generate text from the image
52
- # Note: max_new_tokens can be adjusted based on expected text length
 
53
  generated_ids = model.generate(
54
  input_ids=inputs["input_ids"],
55
  pixel_values=inputs["pixel_values"],
56
- max_new_tokens=4096, # Increased token limit for long documents
57
- do_sample=False, # Use greedy decoding for deterministic output
58
  num_beams=3
59
  )
60
 
61
- # Decode the generated IDs to a string
62
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
63
 
64
- # Post-process the output to get the clean text
65
- # The model's output for OCR is typically in the format: <OCR>extracted_text</s>
66
  parsed_text = processor.post_process_generation(generated_text, task="<OCR>", image_size=(image.width, image.height))
67
 
68
  return parsed_text.get("<OCR>", "Error: Could not parse OCR output.")
@@ -78,21 +86,17 @@ async def perform_ocr(file: UploadFile = File(..., description="Image file to pe
78
  if model is None:
79
  raise HTTPException(status_code=503, detail="Model is not loaded or unavailable.")
80
 
81
- # Validate file type
82
  if not file.content_type.startswith("image/"):
83
  raise HTTPException(status_code=400, detail="Invalid file type. Please upload an image.")
84
 
85
  try:
86
- # Read the image content from the uploaded file
87
  contents = await file.read()
88
  image = Image.open(io.BytesIO(contents))
89
 
90
- # Run the OCR task
91
  logger.info("Running OCR on the uploaded image...")
92
  extracted_text = run_ocr(image)
93
  logger.info("OCR completed successfully.")
94
 
95
- # Return the result
96
  return JSONResponse(
97
  content={"filename": file.filename, "text": extracted_text}
98
  )
@@ -106,4 +110,4 @@ def read_root():
106
  """
107
  A simple health check endpoint to confirm the API is running.
108
  """
109
- return {"status": "ok", "model_loaded": model is not None}
 
14
  app = FastAPI(title="Mixed-Content OCR API", description="An API to extract text from images containing both printed and handwritten text.")
15
 
16
  # --- 2. Load the Model and Processor (at startup) ---
17
+
18
+ # A. Set up the device to use the GPU (T4) if available
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+ logger.info(f"Using device: {device}")
21
+
22
+ # B. Use a memory-efficient dtype for the T4 GPU
23
+ torch_dtype = torch.bfloat16 # T4 GPUs are optimized for bfloat16
24
+
25
  try:
26
  logger.info("Loading model and processor...")
 
27
  model_id = "microsoft/Florence-2-large"
28
+
29
+ # C. Load the model with the specified dtype and send it to the GPU
30
+ model = AutoModelForCausalLM.from_pretrained(
31
+ model_id,
32
+ trust_remote_code=True,
33
+ torch_dtype=torch_dtype
34
+ ).to(device) # <-- Send the model to the GPU
35
+
36
  processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
37
+
38
  logger.info("Model and processor loaded successfully.")
39
  except Exception as e:
40
  logger.error(f"Error loading model: {e}")
41
+ # If the model fails to load, the API is not usable.
42
  model = None
43
  processor = None
44
 
 
50
  if model is None or processor is None:
51
  raise RuntimeError("Model is not available. Check logs for loading errors.")
52
 
 
53
  if image.mode != "RGB":
54
  image = image.convert("RGB")
55
 
 
56
  prompt = "<OCR>"
57
 
58
  # Preprocess the image and prompt
59
  inputs = processor(text=prompt, images=image, return_tensors="pt")
60
 
61
+ # D. IMPORTANT: Move the input tensors to the same device as the model (the GPU)
62
+ inputs = {k: v.to(device, dtype=torch_dtype if k == "pixel_values" else v.dtype) for k, v in inputs.items()}
63
+
64
  generated_ids = model.generate(
65
  input_ids=inputs["input_ids"],
66
  pixel_values=inputs["pixel_values"],
67
+ max_new_tokens=4096,
68
+ do_sample=False,
69
  num_beams=3
70
  )
71
 
 
72
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
73
 
 
 
74
  parsed_text = processor.post_process_generation(generated_text, task="<OCR>", image_size=(image.width, image.height))
75
 
76
  return parsed_text.get("<OCR>", "Error: Could not parse OCR output.")
 
86
  if model is None:
87
  raise HTTPException(status_code=503, detail="Model is not loaded or unavailable.")
88
 
 
89
  if not file.content_type.startswith("image/"):
90
  raise HTTPException(status_code=400, detail="Invalid file type. Please upload an image.")
91
 
92
  try:
 
93
  contents = await file.read()
94
  image = Image.open(io.BytesIO(contents))
95
 
 
96
  logger.info("Running OCR on the uploaded image...")
97
  extracted_text = run_ocr(image)
98
  logger.info("OCR completed successfully.")
99
 
 
100
  return JSONResponse(
101
  content={"filename": file.filename, "text": extracted_text}
102
  )
 
110
  """
111
  A simple health check endpoint to confirm the API is running.
112
  """
113
+ return {"status": "ok", "model_loaded": model is not None, "device": device}