Namra-Satva commited on
Commit
8992be6
·
verified ·
1 Parent(s): de50ef9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -12
app.py CHANGED
@@ -4,6 +4,10 @@ import shutil
4
  import os
5
  from fastapi.middleware.cors import CORSMiddleware
6
  import uuid
 
 
 
 
7
  os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
8
  os.environ["YOLO_CONFIG_DIR"] = "/tmp/ultralytics"
9
  os.environ["XDG_CACHE_HOME"] = "/tmp"
@@ -13,13 +17,11 @@ os.makedirs("/tmp/matplotlib", exist_ok=True)
13
  os.makedirs("/tmp/ultralytics", exist_ok=True)
14
  os.makedirs("/tmp/fontconfig", exist_ok=True)
15
 
16
- from model_utils import extract_invoice_data_from_image
17
-
18
  app = FastAPI()
19
 
20
  UPLOAD_DIR = "/tmp/uploads"
21
  os.makedirs(UPLOAD_DIR, exist_ok=True)
22
- ALLOWED_EXTENSIONS = {".png", ".jpg", ".jpeg"}
23
  app.add_middleware(
24
  CORSMiddleware,
25
  allow_origins=["*"],
@@ -28,23 +30,38 @@ app.add_middleware(
28
  allow_headers=["*"],
29
  )
30
 
 
 
 
 
 
 
31
  @app.post("/extract-invoice")
32
  async def extract_invoice(file: UploadFile = File(...)):
33
  file_ext = os.path.splitext(file.filename)[-1].lower()
34
 
35
  if file_ext not in ALLOWED_EXTENSIONS:
36
- raise HTTPException(status_code=400, detail="Please upload Jpeg, Jpg or Png images only.")
37
 
38
- # Save file to disk
39
  unique_filename = f"{uuid.uuid4().hex}{file_ext}"
40
- file_location = os.path.join(UPLOAD_DIR, unique_filename)
41
 
42
  try:
43
- with open(file_location, "wb") as f:
 
44
  shutil.copyfileobj(file.file, f)
45
 
46
- # Process the file with your model
47
- extracted_data = extract_invoice_data_from_image(file_location)
 
 
 
 
 
 
 
 
 
48
 
49
  return JSONResponse(content={
50
  "success": True,
@@ -59,6 +76,8 @@ async def extract_invoice(file: UploadFile = File(...)):
59
  )
60
 
61
  finally:
62
- # Clean up the file if it exists
63
- if os.path.exists(file_location):
64
- os.remove(file_location)
 
 
 
4
  import os
5
  from fastapi.middleware.cors import CORSMiddleware
6
  import uuid
7
+ from pdf2image import convert_from_path
8
+ from PIL import Image
9
+ from model_utils import extract_invoice_data_from_image
10
+
11
  os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
12
  os.environ["YOLO_CONFIG_DIR"] = "/tmp/ultralytics"
13
  os.environ["XDG_CACHE_HOME"] = "/tmp"
 
17
  os.makedirs("/tmp/ultralytics", exist_ok=True)
18
  os.makedirs("/tmp/fontconfig", exist_ok=True)
19
 
 
 
20
  app = FastAPI()
21
 
22
  UPLOAD_DIR = "/tmp/uploads"
23
  os.makedirs(UPLOAD_DIR, exist_ok=True)
24
+ ALLOWED_EXTENSIONS = {".png", ".jpg", ".jpeg",".pdf"}
25
  app.add_middleware(
26
  CORSMiddleware,
27
  allow_origins=["*"],
 
30
  allow_headers=["*"],
31
  )
32
 
33
+ def resize_to_640(img: Image.Image) -> Image.Image:
34
+ w_percent = 640 / float(img.size[0])
35
+ h_size = int((float(img.size[1]) * float(w_percent)))
36
+ return img.resize((640, h_size), Image.LANCZOS)
37
+
38
+
39
  @app.post("/extract-invoice")
40
  async def extract_invoice(file: UploadFile = File(...)):
41
  file_ext = os.path.splitext(file.filename)[-1].lower()
42
 
43
  if file_ext not in ALLOWED_EXTENSIONS:
44
+ raise HTTPException(status_code=400, detail="Supported formats: .png, .jpg, .jpeg, .pdf")
45
 
 
46
  unique_filename = f"{uuid.uuid4().hex}{file_ext}"
47
+ file_path = os.path.join(UPLOAD_DIR, unique_filename)
48
 
49
  try:
50
+ # Save uploaded file
51
+ with open(file_path, "wb") as f:
52
  shutil.copyfileobj(file.file, f)
53
 
54
+ if file_ext == ".pdf":
55
+ # Convert PDF's first page to image
56
+ images = convert_from_path(file_path, dpi=300)
57
+ img = resize_to_640(images[0])
58
+ image_path = os.path.join(UPLOAD_DIR, f"{uuid.uuid4().hex}.png")
59
+ img.save(image_path)
60
+ else:
61
+ image_path = file_path
62
+
63
+ # Run inference
64
+ extracted_data = extract_invoice_data_from_image(image_path)
65
 
66
  return JSONResponse(content={
67
  "success": True,
 
76
  )
77
 
78
  finally:
79
+ # Clean up temp files
80
+ if os.path.exists(file_path):
81
+ os.remove(file_path)
82
+ if file_ext == ".pdf" and os.path.exists(image_path):
83
+ os.remove(image_path)