OmarAbualrob commited on
Commit
bf69385
·
verified ·
1 Parent(s): fcf63e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -47
app.py CHANGED
@@ -1,79 +1,55 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from fastapi.responses import JSONResponse
3
- from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
4
  from PIL import Image
5
- import torch
6
  import io
7
  import logging
8
 
9
- # --- 1. Basic Setup ---
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
- app = FastAPI(title="Florence-2 OCR API", description="An API to extract text from images using the Florence-2-large model on a GPU.")
13
 
14
- # --- 2. Global Variables and Device Configuration ---
15
- device = "cuda" if torch.cuda.is_available() else "cpu"
16
- torch_dtype = torch.bfloat16
17
  model = None
18
  processor = None
19
 
20
- # --- 3. Model Loading Logic (at startup) ---
21
  @app.on_event("startup")
22
  async def startup_event():
23
  global model, processor
24
-
25
- if device == "cpu":
26
- logger.warning("CUDA not available, model will not be loaded. This API requires a GPU.")
27
- return
28
-
29
  try:
30
  logger.info(f"Using device: {device}")
31
- logger.info("Starting model loading process with 4-bit quantization...")
32
 
33
  model_id = "microsoft/Florence-2-large"
34
 
35
- quantization_config = BitsAndBytesConfig(
36
- load_in_4bit=True,
37
- bnb_4bit_compute_dtype=torch_dtype
38
- )
39
-
40
- # Load the model WITHOUT the invalid revision ID
41
  model = AutoModelForCausalLM.from_pretrained(
42
  model_id,
43
  trust_remote_code=True,
44
- quantization_config=quantization_config,
45
- # revision="e134b72" <-- REMOVED THIS LINE
46
  )
47
 
48
- # Load the processor WITHOUT the invalid revision ID
49
- processor = AutoProcessor.from_pretrained(
50
- model_id,
51
- trust_remote_code=True
52
- # revision="e134b72" <-- REMOVED THIS LINE
53
- )
54
 
55
- logger.info("Model and processor loaded successfully.")
56
 
57
  except Exception as e:
58
  logger.error(f"FATAL: An error occurred during model loading: {e}", exc_info=True)
59
 
60
- # --- 4. Define the OCR Task Function ---
61
  def run_ocr(image: Image.Image) -> str:
62
  if model is None or processor is None:
63
  raise RuntimeError("Model is not available. Check startup logs for loading errors.")
64
-
65
  if image.mode != "RGB":
66
  image = image.convert("RGB")
67
-
68
  prompt = "<OCR>"
69
  inputs = processor(text=prompt, images=image, return_tensors="pt")
70
 
71
- input_ids = inputs["input_ids"].to(device)
72
- pixel_values = inputs["pixel_values"].to(device, dtype=torch_dtype)
73
-
74
  generated_ids = model.generate(
75
- input_ids=input_ids,
76
- pixel_values=pixel_values,
77
  max_new_tokens=4096,
78
  do_sample=False,
79
  num_beams=3
@@ -81,30 +57,24 @@ def run_ocr(image: Image.Image) -> str:
81
 
82
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
83
  parsed_text = processor.post_process_generation(generated_text, task="<OCR>", image_size=(image.width, image.height))
84
-
85
  return parsed_text.get("<OCR>", "Error: Could not parse OCR output.")
86
 
87
 
88
- # --- 5. Create API Endpoints ---
 
89
  @app.post("/ocr", summary="Extract Text from Image")
90
  async def perform_ocr(file: UploadFile = File(..., description="Image file to perform OCR on.")):
91
  if model is None:
92
  raise HTTPException(status_code=503, detail="Model is not loaded or unavailable. Please check the server logs.")
93
-
94
  if not file.content_type.startswith("image/"):
95
- raise HTTPException(status_code=400, detail="Invalid file type. Please upload an image (e.g., PNG, JPG).")
96
-
97
  try:
98
  contents = await file.read()
99
  image = Image.open(io.BytesIO(contents))
100
-
101
  logger.info(f"Running OCR on uploaded file: {file.filename}")
102
  extracted_text = run_ocr(image)
103
  logger.info("OCR completed successfully.")
104
-
105
- return JSONResponse(
106
- content={"filename": file.filename, "text": extracted_text}
107
- )
108
  except Exception as e:
109
  logger.error(f"An error occurred during OCR processing for {file.filename}: {e}", exc_info=True)
110
  raise HTTPException(status_code=500, detail=f"An internal server error occurred: {str(e)}")
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from fastapi.responses import JSONResponse
3
+ from transformers import AutoProcessor, AutoModelForCausalLM
4
  from PIL import Image
 
5
  import io
6
  import logging
7
 
 
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
+ app = FastAPI(title="Florence-2 OCR API (CPU)", description="An API to extract text from images using the Florence-2-large model on CPU.")
11
 
12
+ # --- Global Variables and Device Configuration ---
13
+ device = "cpu" # Force CPU
 
14
  model = None
15
  processor = None
16
 
17
+ # --- Model Loading Logic (at startup) ---
18
  @app.on_event("startup")
19
  async def startup_event():
20
  global model, processor
 
 
 
 
 
21
  try:
22
  logger.info(f"Using device: {device}")
23
+ logger.info("Starting model loading process for CPU...")
24
 
25
  model_id = "microsoft/Florence-2-large"
26
 
27
+ # Load the model in full precision for CPU
 
 
 
 
 
28
  model = AutoModelForCausalLM.from_pretrained(
29
  model_id,
30
  trust_remote_code=True,
 
 
31
  )
32
 
33
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
 
 
 
 
 
34
 
35
+ logger.info("Model and processor loaded successfully on CPU.")
36
 
37
  except Exception as e:
38
  logger.error(f"FATAL: An error occurred during model loading: {e}", exc_info=True)
39
 
40
+ # --- Define the OCR Task Function (CPU version) ---
41
  def run_ocr(image: Image.Image) -> str:
42
  if model is None or processor is None:
43
  raise RuntimeError("Model is not available. Check startup logs for loading errors.")
 
44
  if image.mode != "RGB":
45
  image = image.convert("RGB")
 
46
  prompt = "<OCR>"
47
  inputs = processor(text=prompt, images=image, return_tensors="pt")
48
 
49
+ # Generate on CPU (no .to(device) or dtype changes needed)
 
 
50
  generated_ids = model.generate(
51
+ input_ids=inputs["input_ids"],
52
+ pixel_values=inputs["pixel_values"],
53
  max_new_tokens=4096,
54
  do_sample=False,
55
  num_beams=3
 
57
 
58
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
59
  parsed_text = processor.post_process_generation(generated_text, task="<OCR>", image_size=(image.width, image.height))
 
60
  return parsed_text.get("<OCR>", "Error: Could not parse OCR output.")
61
 
62
 
63
+ # --- API Endpoints ---
64
+ # (Your @app.post and @app.get endpoints remain exactly the same)
65
  @app.post("/ocr", summary="Extract Text from Image")
66
  async def perform_ocr(file: UploadFile = File(..., description="Image file to perform OCR on.")):
67
  if model is None:
68
  raise HTTPException(status_code=503, detail="Model is not loaded or unavailable. Please check the server logs.")
 
69
  if not file.content_type.startswith("image/"):
70
+ raise HTTPException(status_code=400, detail="Invalid file type. Please upload an image.")
 
71
  try:
72
  contents = await file.read()
73
  image = Image.open(io.BytesIO(contents))
 
74
  logger.info(f"Running OCR on uploaded file: {file.filename}")
75
  extracted_text = run_ocr(image)
76
  logger.info("OCR completed successfully.")
77
+ return JSONResponse(content={"filename": file.filename, "text": extracted_text})
 
 
 
78
  except Exception as e:
79
  logger.error(f"An error occurred during OCR processing for {file.filename}: {e}", exc_info=True)
80
  raise HTTPException(status_code=500, detail=f"An internal server error occurred: {str(e)}")