Update app.py
Browse files
app.py
CHANGED
|
@@ -33,20 +33,9 @@ if not hf_token:
|
|
| 33 |
logger.error("HF_TOKEN environment variable not set")
|
| 34 |
raise HTTPException(status_code=500, detail="HF_TOKEN environment variable not set")
|
| 35 |
|
| 36 |
-
# Initialize Hugging Face Inference Client
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
try:
|
| 40 |
-
client = InferenceClient(model=primary_model, token=hf_token, provider="auto")
|
| 41 |
-
logger.info(f"Hugging Face Inference Client initialized for {primary_model} with provider='auto'")
|
| 42 |
-
except Exception as e:
|
| 43 |
-
logger.warning(f"Failed to initialize client for {primary_model}: {str(e)}. Falling back to {fallback_model}")
|
| 44 |
-
try:
|
| 45 |
-
client = InferenceClient(model=fallback_model, token=hf_token, provider="hf-inference")
|
| 46 |
-
logger.info(f"Hugging Face Inference Client initialized for {fallback_model} with provider='hf-inference'")
|
| 47 |
-
except Exception as e:
|
| 48 |
-
logger.error(f"Failed to initialize client for {fallback_model}: {str(e)}")
|
| 49 |
-
raise HTTPException(status_code=500, detail=f"Failed to initialize Inference Client: {str(e)}")
|
| 50 |
|
| 51 |
# In-memory caches (1-hour TTL)
|
| 52 |
raw_text_cache = cachetools.TTLCache(maxsize=100, ttl=3600)
|
|
@@ -71,7 +60,7 @@ async def process_image(img_bytes, filename, idx):
|
|
| 71 |
start_time = time.time()
|
| 72 |
logger.info(f"Starting OCR for {filename} image {idx}, {log_memory_usage()}")
|
| 73 |
try:
|
| 74 |
-
img = Image.open(io.BytesIO(img_bytes)).resize((
|
| 75 |
img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
| 76 |
gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
|
| 77 |
img_pil = Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
|
|
@@ -88,7 +77,7 @@ async def process_pdf_page(img, page_idx):
|
|
| 88 |
start_time = time.time()
|
| 89 |
logger.info(f"Starting OCR for PDF page {page_idx}, {log_memory_usage()}")
|
| 90 |
try:
|
| 91 |
-
img = img.resize((
|
| 92 |
img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
| 93 |
gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
|
| 94 |
img_pil = Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
|
|
@@ -100,7 +89,7 @@ async def process_pdf_page(img, page_idx):
|
|
| 100 |
logger.error(f"OCR failed for PDF page {page_idx}: {str(e)}, {log_memory_usage()}")
|
| 101 |
return ""
|
| 102 |
|
| 103 |
-
async def
|
| 104 |
"""Process raw text with LLM via Hugging Face Inference API."""
|
| 105 |
start_time = time.time()
|
| 106 |
logger.info(f"Starting LLM API processing for {filename}, {log_memory_usage()}")
|
|
@@ -116,43 +105,69 @@ async def process_with_qwen(filename: str, raw_text: str):
|
|
| 116 |
raw_text = raw_text[:2000]
|
| 117 |
logger.info(f"Truncated raw text for {filename} to 2000 characters, {log_memory_usage()}")
|
| 118 |
|
| 119 |
-
try
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
{{
|
| 125 |
-
"currency": "",
|
| 126 |
-
"Name_Client": "",
|
| 127 |
-
"Products": [],
|
| 128 |
-
"Subtotal": "",
|
| 129 |
-
"Tax": "",
|
| 130 |
-
"total": "",
|
| 131 |
-
"invoice date": "",
|
| 132 |
-
"invoice number": ""
|
| 133 |
-
}}
|
| 134 |
-
"""
|
| 135 |
-
# Call Hugging Face Inference API
|
| 136 |
-
response = await asyncio.to_thread(client.chat_completion,
|
| 137 |
-
messages=[{"role": "user", "content": prompt}],
|
| 138 |
-
max_tokens=256,
|
| 139 |
-
temperature=0.7
|
| 140 |
-
)
|
| 141 |
-
llm_output = response.choices[0].message.content
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
@app.post("/ocr")
|
| 158 |
async def extract_and_structure(files: List[UploadFile] = File(...)):
|
|
@@ -265,10 +280,10 @@ async def extract_and_structure(files: List[UploadFile] = File(...)):
|
|
| 265 |
raw_text_cache[file_hash] = raw_text
|
| 266 |
logger.info(f"Text normalization for {file.filename}, took {time.time() - normalize_start_time:.2f} seconds, text length: {len(raw_text)}, {log_memory_usage()}")
|
| 267 |
except Exception as e:
|
| 268 |
-
logger.warning(f"Text normalization failed for {
|
| 269 |
|
| 270 |
# Process with LLM API
|
| 271 |
-
structured_data = await
|
| 272 |
success_count += 1
|
| 273 |
output_json["data"].append({
|
| 274 |
"filename": file.filename,
|
|
|
|
| 33 |
logger.error("HF_TOKEN environment variable not set")
|
| 34 |
raise HTTPException(status_code=500, detail="HF_TOKEN environment variable not set")
|
| 35 |
|
| 36 |
+
# Initialize Hugging Face Inference Client
|
| 37 |
+
client = InferenceClient(token=hf_token)
|
| 38 |
+
logger.info("Hugging Face Inference Client initialized")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
# In-memory caches (1-hour TTL)
|
| 41 |
raw_text_cache = cachetools.TTLCache(maxsize=100, ttl=3600)
|
|
|
|
| 60 |
start_time = time.time()
|
| 61 |
logger.info(f"Starting OCR for {filename} image {idx}, {log_memory_usage()}")
|
| 62 |
try:
|
| 63 |
+
img = Image.open(io.BytesIO(img_bytes)).resize((600, 400)) # Smaller for speed
|
| 64 |
img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
| 65 |
gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
|
| 66 |
img_pil = Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
|
|
|
|
| 77 |
start_time = time.time()
|
| 78 |
logger.info(f"Starting OCR for PDF page {page_idx}, {log_memory_usage()}")
|
| 79 |
try:
|
| 80 |
+
img = img.resize((600, 400)) # Smaller for speed
|
| 81 |
img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
| 82 |
gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
|
| 83 |
img_pil = Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
|
|
|
|
| 89 |
logger.error(f"OCR failed for PDF page {page_idx}: {str(e)}, {log_memory_usage()}")
|
| 90 |
return ""
|
| 91 |
|
| 92 |
+
async def process_with_llm(filename: str, raw_text: str):
|
| 93 |
"""Process raw text with LLM via Hugging Face Inference API."""
|
| 94 |
start_time = time.time()
|
| 95 |
logger.info(f"Starting LLM API processing for {filename}, {log_memory_usage()}")
|
|
|
|
| 105 |
raw_text = raw_text[:2000]
|
| 106 |
logger.info(f"Truncated raw text for {filename} to 2000 characters, {log_memory_usage()}")
|
| 107 |
|
| 108 |
+
# Define models to try with retry logic
|
| 109 |
+
models = [
|
| 110 |
+
{"model": "google/gemma-2-9b-it", "provider": "auto"},
|
| 111 |
+
{"model": "meta-llama/Meta-Llama-3-8B-Instruct", "provider": "auto"}
|
| 112 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
+
for model_info in models:
|
| 115 |
+
model = model_info["model"]
|
| 116 |
+
provider = model_info["provider"]
|
| 117 |
+
logger.info(f"Attempting LLM API call with model {model} and provider {provider}")
|
| 118 |
+
for attempt in range(2): # Retry once
|
| 119 |
+
try:
|
| 120 |
+
prompt = f"""
|
| 121 |
+
Extract key invoice fields as JSON from the raw text. Support English. Detect currency (e.g., USD, INR). Output only valid JSON, with no additional text, comments, or markdown.
|
| 122 |
+
Raw text: {raw_text}
|
| 123 |
+
Output JSON:
|
| 124 |
+
{{
|
| 125 |
+
"currency": "",
|
| 126 |
+
"Name_Client": "",
|
| 127 |
+
"Products": [],
|
| 128 |
+
"Subtotal": "",
|
| 129 |
+
"Tax": "",
|
| 130 |
+
"total": "",
|
| 131 |
+
"invoice date": "",
|
| 132 |
+
"invoice number": ""
|
| 133 |
+
}}
|
| 134 |
+
"""
|
| 135 |
+
# Call Hugging Face Inference API
|
| 136 |
+
response = await asyncio.to_thread(client.chat_completion,
|
| 137 |
+
model=model,
|
| 138 |
+
messages=[{"role": "user", "content": prompt}],
|
| 139 |
+
max_tokens=256,
|
| 140 |
+
temperature=0.7,
|
| 141 |
+
provider=provider
|
| 142 |
+
)
|
| 143 |
+
llm_output = response.choices[0].message.content
|
| 144 |
+
|
| 145 |
+
# Extract JSON from output
|
| 146 |
+
llm_output = llm_output.strip()
|
| 147 |
+
if not llm_output.startswith("{"):
|
| 148 |
+
raise ValueError("API output is not valid JSON")
|
| 149 |
+
json_start = llm_output.find("{")
|
| 150 |
+
json_end = llm_output.rfind("}") + 1
|
| 151 |
+
json_str = llm_output[json_start:json_end]
|
| 152 |
+
try:
|
| 153 |
+
structured_data = json.loads(json_str)
|
| 154 |
+
except json.JSONDecodeError:
|
| 155 |
+
logger.warning(f"JSON parsing failed for {filename}, attempting to fix")
|
| 156 |
+
json_str = llm_output[llm_output.find("{"):llm_output.rfind("}")+1]
|
| 157 |
+
structured_data = json.loads(json_str)
|
| 158 |
+
structured_data_cache[text_hash] = structured_data
|
| 159 |
+
logger.info(f"LLM API processing for {filename} with {model}, attempt {attempt+1}, took {time.time() - start_time:.2f} seconds, {log_memory_usage()}")
|
| 160 |
+
return structured_data
|
| 161 |
+
except Exception as e:
|
| 162 |
+
logger.warning(f"LLM API processing failed for {filename} with {model}, attempt {attempt+1}: {str(e)}, {log_memory_usage()}")
|
| 163 |
+
if attempt == 1: # No more retries
|
| 164 |
+
break
|
| 165 |
+
await asyncio.sleep(1) # Wait before retry
|
| 166 |
+
|
| 167 |
+
# If all models fail
|
| 168 |
+
error_msg = "All LLM API models failed. Check model availability, authentication, or rate limits."
|
| 169 |
+
logger.error(f"{error_msg} for {filename}, {log_memory_usage()}")
|
| 170 |
+
return {"error": error_msg}
|
| 171 |
|
| 172 |
@app.post("/ocr")
|
| 173 |
async def extract_and_structure(files: List[UploadFile] = File(...)):
|
|
|
|
| 280 |
raw_text_cache[file_hash] = raw_text
|
| 281 |
logger.info(f"Text normalization for {file.filename}, took {time.time() - normalize_start_time:.2f} seconds, text length: {len(raw_text)}, {log_memory_usage()}")
|
| 282 |
except Exception as e:
|
| 283 |
+
logger.warning(f"Text normalization failed for {filename}: {str(e)}, {log_memory_usage()}")
|
| 284 |
|
| 285 |
# Process with LLM API
|
| 286 |
+
structured_data = await process_with_llm(file.filename, raw_text)
|
| 287 |
success_count += 1
|
| 288 |
output_json["data"].append({
|
| 289 |
"filename": file.filename,
|