Spaces:
Sleeping
Sleeping
File size: 7,797 Bytes
d56c6ae 4972899 d56c6ae 4972899 d56c6ae 4972899 d56c6ae 4972899 d56c6ae 4972899 d56c6ae 4972899 d56c6ae 4972899 d56c6ae 4972899 d56c6ae | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 |
import asyncio
import hashlib
import io
import os
import re
from typing import List, Dict, Any

from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image

from src.pdfconverter import pdf_to_images
from src.vision import classify_image
from src.visual_cues import detect_logos_from_bytes
from src.config import (
    UPLOAD_DIR,
    ALLOWED_EXTENSIONS,
    MAX_TOTAL_FILES,
    MAX_PDFS,
    MAX_IMAGES,
    MAX_IMAGE_MB,
    MAX_PDF_MB,
    MIN_WIDTH,
    MIN_HEIGHT,
    MAX_WIDTH,
    MAX_HEIGHT,
    MAX_VISUAL_PAGES,
    MAX_LOGOS_PER_PAGE,
    MAX_IMAGE_RESIZE,
)
# --------------------------------------------------
# FASTAPI APPLICATION
# --------------------------------------------------
app = FastAPI(title="DocVision API")

# --------------------------------------------------
# CORS (REQUIRED FOR HUGGING FACE)
# --------------------------------------------------
# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): a wildcard origin together with allow_credentials=True is the
# most permissive CORS setup possible; confirm it is intended, especially if
# cookie- or header-based authentication is ever added.
_CORS_OPTIONS: Dict[str, Any] = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# --------------------------------------------------
# HEALTH CHECK
# --------------------------------------------------
@app.get("/")
def health() -> Dict[str, str]:
    """Answer ``GET /`` with a fixed ``{"status": "ok"}`` payload.

    Used as a liveness probe for routing and monitoring.
    """
    payload: Dict[str, str] = dict(status="ok")
    return payload
# --------------------------------------------------
# DIRECTORIES
# --------------------------------------------------
# The endpoints write uploaded files into UPLOAD_DIR, so create it once at
# import time (a no-op when it already exists).
os.makedirs(name=UPLOAD_DIR, exist_ok=True)

# --------------------------------------------------
# IN-MEMORY CACHES
# --------------------------------------------------
# Per-process result caches keyed by "<filename>_<content hash>":
#   TEXT_CACHE   -> results of /analyze
#   VISUAL_CACHE -> results of /visual_cues
# NOTE(review): nothing in this file evicts entries, so both grow for the
# lifetime of the process; consider a bounded cache if memory becomes an issue.
TEXT_CACHE: Dict[str, Dict[str, Any]] = dict()
VISUAL_CACHE: Dict[str, Dict[str, Any]] = dict()
# --------------------------------------------------
# HELPER FUNCTIONS
# --------------------------------------------------
def file_hash(data: bytes) -> str:
    """Return the hex MD5 digest of *data* as a deterministic content fingerprint.

    The digest identifies uploaded content (for example in the result-cache
    keys); it is not a security control. ``usedforsecurity=False`` declares
    that explicitly, which keeps the call working on FIPS-restricted builds
    that otherwise refuse MD5.
    """
    return hashlib.md5(data, usedforsecurity=False).hexdigest()
def read_file(file: UploadFile) -> bytes:
    """Return the upload's bytes, then rewind its stream to offset 0.

    The stream is read from its current position to the end. Rewinding
    afterwards lets later code read the same upload again.
    """
    stream = file.file
    payload = stream.read()
    stream.seek(0)
    return payload
def validate_file(file: UploadFile, contents: bytes) -> str | None:
    """
    Validate file type, size, and image resolution.

    Args:
        file: The upload; only its ``filename`` (extension) is consulted.
        contents: The complete file bytes.

    Returns:
        An error message describing the first failed check, or ``None`` when
        the file is acceptable. PDFs are checked for extension and size only;
        images are also decoded to check their pixel dimensions.
    """
    # A missing filename is treated like one without an extension, so it is
    # rejected as unsupported instead of crashing os.path.splitext.
    ext = os.path.splitext(file.filename or "")[1].lower()
    size_mb = len(contents) / (1024 * 1024)
    is_pdf = ext == ".pdf"

    if ext not in ALLOWED_EXTENSIONS:
        return "Unsupported file format"
    if is_pdf:
        # PDF content is not inspected here; only its size is bounded.
        if size_mb > MAX_PDF_MB:
            return f"PDF exceeds {MAX_PDF_MB} MB"
        return None
    if size_mb > MAX_IMAGE_MB:
        return f"Image exceeds {MAX_IMAGE_MB} MB"

    # Only the decode can fail, so keep the try body to that. The context
    # manager closes the PIL image as soon as its size has been read.
    try:
        with Image.open(io.BytesIO(contents)) as image:
            width, height = image.size
    except Exception:
        # Any failure to open/identify the image is reported uniformly.
        return "Invalid image file"

    if width < MIN_WIDTH or height < MIN_HEIGHT:
        return f"Image too small ({width}x{height})"
    if width > MAX_WIDTH or height > MAX_HEIGHT:
        return f"Image too large ({width}x{height})"
    return None
# --------------------------------------------------
# DOCUMENT ANALYSIS ENDPOINT
# --------------------------------------------------
@app.post("/analyze")
async def analyze(files: List[UploadFile] = File(...)) -> JSONResponse:
    """
    Perform OCR + Vision-based document classification.

    Each upload is validated, saved under ``UPLOAD_DIR`` and passed to
    ``classify_image``. For a PDF, the pages are rendered with
    ``pdf_to_images`` and the first page image is classified. Uploads are
    processed concurrently, and one result dict is returned per upload in
    request order. Per-file failures come back as ``{"file": ..., "error": ...}``
    entries rather than failing the whole request. Successful results are
    cached in ``TEXT_CACHE`` under "<filename>_<content hash>".

    Returns a 400 response when more than ``MAX_TOTAL_FILES`` files are sent.
    """
    if len(files) > MAX_TOTAL_FILES:
        return JSONResponse(
            {"error": f"Maximum {MAX_TOTAL_FILES} files allowed"},
            status_code=400,
        )

    # Per-kind limits apply to the request as a whole: when one is exceeded,
    # every file of that kind is rejected.
    pdf_count = sum(f.filename.lower().endswith(".pdf") for f in files)
    img_count = len(files) - pdf_count

    async def process_file(file: UploadFile) -> Dict[str, Any]:
        is_pdf = file.filename.lower().endswith(".pdf")

        # The count limits need no file contents, so check them before
        # reading and hashing the upload.
        if is_pdf and pdf_count > MAX_PDFS:
            return {"file": file.filename, "error": f"Maximum {MAX_PDFS} PDFs allowed"}
        if not is_pdf and img_count > MAX_IMAGES:
            return {"file": file.filename, "error": f"Maximum {MAX_IMAGES} images allowed"}

        contents = read_file(file)
        digest = file_hash(contents)
        fid = f"{file.filename}_{digest}"
        if fid in TEXT_CACHE:
            return TEXT_CACHE[fid]

        error = validate_file(file, contents)
        if error:
            return {"file": file.filename, "error": error}

        # The client-supplied filename is untrusted. basename() stops "../" or
        # absolute names from escaping UPLOAD_DIR. The content-hash prefix
        # gives different uploads that share a name different paths, so one
        # cannot overwrite another's file while it is still being processed.
        safe_name = f"{digest}_{os.path.basename(file.filename)}"
        path = os.path.join(UPLOAD_DIR, safe_name)
        with open(path, "wb") as out:
            out.write(contents)

        try:
            if is_pdf:
                pdf_name = await asyncio.to_thread(pdf_to_images, path)
                # NOTE(review): this directory is hard-coded as "uploads/images"
                # rather than derived from UPLOAD_DIR; presumably it must match
                # where pdf_to_images writes pages. Confirm before changing either.
                base_dir = os.path.join("uploads", "images", pdf_name)
                pages = os.listdir(base_dir)
                if not pages:
                    return {
                        "file": file.filename,
                        "error": "Processing failed: PDF produced no page images",
                    }
                # The lexicographically smallest file name is taken as page 1.
                analysis = await classify_image(os.path.join(base_dir, min(pages)))
            else:
                analysis = await classify_image(path)

            result = {
                "file": file.filename,
                "document_type": analysis.get("document_type"),
                "reasoning": analysis.get("reasoning"),
                "extracted_textfields": analysis.get("extracted_textfields", {}),
            }
            TEXT_CACHE[fid] = result
            return result
        except Exception as exc:
            # Per-file best effort: report the failure instead of failing the request.
            return {"file": file.filename, "error": f"Processing failed: {exc}"}

    results = await asyncio.gather(*[process_file(f) for f in files])
    return JSONResponse(content=results)
# --------------------------------------------------
# VISUAL CUES ENDPOINT
# --------------------------------------------------
@app.post("/visual_cues")
async def visual_cues(files: List[UploadFile] = File(...)) -> JSONResponse:
    """
    Detect logos, seals, and visual symbols from documents.

    Each upload is validated and saved under ``UPLOAD_DIR``. An image is
    passed directly to ``detect_logos_from_bytes``. A PDF is rendered with
    ``pdf_to_images`` and its first ``MAX_VISUAL_PAGES`` page images are
    scanned. Uploads are processed concurrently, and one result dict is
    returned per upload in request order. Per-file failures come back as
    ``{"file": ..., "error": ...}`` entries. Successful results are cached in
    ``VISUAL_CACHE`` under "<filename>_<content hash>".

    The same upload-count limits as ``/analyze`` apply: more than
    ``MAX_TOTAL_FILES`` files yields a 400 response.
    """
    if len(files) > MAX_TOTAL_FILES:
        return JSONResponse(
            {"error": f"Maximum {MAX_TOTAL_FILES} files allowed"},
            status_code=400,
        )

    # Per-kind limits apply to the request as a whole, as in /analyze.
    pdf_count = sum(f.filename.lower().endswith(".pdf") for f in files)
    img_count = len(files) - pdf_count

    def page_order(name: str) -> List[Any]:
        # Natural sort key: digit runs (odd positions of the split) compare
        # numerically, so "page_2" sorts before "page_10". A plain sort would
        # put "page_10" first and the MAX_VISUAL_PAGES cut-off could skip
        # early pages. For zero-padded names both orders are identical.
        # NOTE(review): the page naming scheme comes from pdf_to_images
        # (not visible here).
        parts = re.split(r"(\d+)", name)
        return [int(part) if i % 2 else part for i, part in enumerate(parts)]

    def detect_page(page_path: str) -> Any:
        # Runs in a worker thread so both the file read and the detection
        # stay off the event loop.
        with open(page_path, "rb") as img_file:
            return detect_logos_from_bytes(
                img_file.read(),
                MAX_IMAGE_RESIZE,
                MAX_LOGOS_PER_PAGE,
            )

    async def process_visual(file: UploadFile) -> Dict[str, Any]:
        is_pdf = file.filename.lower().endswith(".pdf")
        if is_pdf and pdf_count > MAX_PDFS:
            return {"file": file.filename, "error": f"Maximum {MAX_PDFS} PDFs allowed"}
        if not is_pdf and img_count > MAX_IMAGES:
            return {"file": file.filename, "error": f"Maximum {MAX_IMAGES} images allowed"}

        contents = read_file(file)
        digest = file_hash(contents)
        fid = f"{file.filename}_{digest}"
        if fid in VISUAL_CACHE:
            return VISUAL_CACHE[fid]

        error = validate_file(file, contents)
        if error:
            return {"file": file.filename, "error": error}

        # The client-supplied filename is untrusted. basename() stops "../" or
        # absolute names from escaping UPLOAD_DIR. The content-hash prefix
        # gives different uploads that share a name different paths, so one
        # cannot overwrite another's file while it is still being processed.
        safe_name = f"{digest}_{os.path.basename(file.filename)}"
        path = os.path.join(UPLOAD_DIR, safe_name)
        with open(path, "wb") as out:
            out.write(contents)

        visuals: List[Dict[str, Any]] = []
        try:
            if is_pdf:
                pdf_name = await asyncio.to_thread(pdf_to_images, path)
                # NOTE(review): hard-coded "uploads/images" rather than
                # UPLOAD_DIR; presumably it matches where pdf_to_images writes.
                base_dir = os.path.join("uploads", "images", pdf_name)
                pages = sorted(os.listdir(base_dir), key=page_order)
                for img_name in pages[:MAX_VISUAL_PAGES]:
                    logos = await asyncio.to_thread(
                        detect_page, os.path.join(base_dir, img_name)
                    )
                    visuals.append({"page": img_name, "logos": logos})
            else:
                logos = await asyncio.to_thread(
                    detect_logos_from_bytes,
                    contents,
                    MAX_IMAGE_RESIZE,
                    MAX_LOGOS_PER_PAGE,
                )
                visuals.append({"page": "image", "logos": logos})

            result = {"file": file.filename, "visual_cues": visuals}
            VISUAL_CACHE[fid] = result
            return result
        except Exception as exc:
            # Per-file best effort: report the failure instead of failing the request.
            return {"file": file.filename, "error": f"Visual processing failed: {exc}"}

    results = await asyncio.gather(*[process_visual(f) for f in files])
    return JSONResponse(content=results)
|