RajanMalaviya commited on
Commit
df4f589
·
verified ·
1 Parent(s): e86531f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -55
app.py CHANGED
@@ -33,20 +33,9 @@ if not hf_token:
33
  logger.error("HF_TOKEN environment variable not set")
34
  raise HTTPException(status_code=500, detail="HF_TOKEN environment variable not set")
35
 
36
- # Initialize Hugging Face Inference Client with primary and fallback models
37
- primary_model = "mistral/Mixtral-8x7B-Instruct-v0.1"
38
- fallback_model = "Qwen/Qwen2-7B-Instruct"
39
- try:
40
- client = InferenceClient(model=primary_model, token=hf_token, provider="auto")
41
- logger.info(f"Hugging Face Inference Client initialized for {primary_model} with provider='auto'")
42
- except Exception as e:
43
- logger.warning(f"Failed to initialize client for {primary_model}: {str(e)}. Falling back to {fallback_model}")
44
- try:
45
- client = InferenceClient(model=fallback_model, token=hf_token, provider="hf-inference")
46
- logger.info(f"Hugging Face Inference Client initialized for {fallback_model} with provider='hf-inference'")
47
- except Exception as e:
48
- logger.error(f"Failed to initialize client for {fallback_model}: {str(e)}")
49
- raise HTTPException(status_code=500, detail=f"Failed to initialize Inference Client: {str(e)}")
50
 
51
  # In-memory caches (1-hour TTL)
52
  raw_text_cache = cachetools.TTLCache(maxsize=100, ttl=3600)
@@ -71,7 +60,7 @@ async def process_image(img_bytes, filename, idx):
71
  start_time = time.time()
72
  logger.info(f"Starting OCR for {filename} image {idx}, {log_memory_usage()}")
73
  try:
74
- img = Image.open(io.BytesIO(img_bytes)).resize((800, 600)) # Resize for faster OCR
75
  img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
76
  gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
77
  img_pil = Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
@@ -88,7 +77,7 @@ async def process_pdf_page(img, page_idx):
88
  start_time = time.time()
89
  logger.info(f"Starting OCR for PDF page {page_idx}, {log_memory_usage()}")
90
  try:
91
- img = img.resize((800, 600)) # Resize for faster OCR
92
  img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
93
  gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
94
  img_pil = Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
@@ -100,7 +89,7 @@ async def process_pdf_page(img, page_idx):
100
  logger.error(f"OCR failed for PDF page {page_idx}: {str(e)}, {log_memory_usage()}")
101
  return ""
102
 
103
- async def process_with_qwen(filename: str, raw_text: str):
104
  """Process raw text with LLM via Hugging Face Inference API."""
105
  start_time = time.time()
106
  logger.info(f"Starting LLM API processing for {filename}, {log_memory_usage()}")
@@ -116,43 +105,69 @@ async def process_with_qwen(filename: str, raw_text: str):
116
  raw_text = raw_text[:2000]
117
  logger.info(f"Truncated raw text for {filename} to 2000 characters, {log_memory_usage()}")
118
 
119
- try:
120
- prompt = f"""
121
- Extract key invoice fields as JSON from the raw text. Support English. Detect currency (e.g., USD, INR). Output only valid JSON, with no additional text, comments, or markdown.
122
- Raw text: {raw_text}
123
- Output JSON:
124
- {{
125
- "currency": "",
126
- "Name_Client": "",
127
- "Products": [],
128
- "Subtotal": "",
129
- "Tax": "",
130
- "total": "",
131
- "invoice date": "",
132
- "invoice number": ""
133
- }}
134
- """
135
- # Call Hugging Face Inference API
136
- response = await asyncio.to_thread(client.chat_completion,
137
- messages=[{"role": "user", "content": prompt}],
138
- max_tokens=256,
139
- temperature=0.7
140
- )
141
- llm_output = response.choices[0].message.content
142
 
143
- # Extract JSON from output
144
- json_start = llm_output.find("{")
145
- json_end = llm_output.rfind("}") + 1
146
- if json_start == -1 or json_end == -1:
147
- raise ValueError("No valid JSON found in API output")
148
- json_str = llm_output[json_start:json_end]
149
- structured_data = json.loads(json_str)
150
- structured_data_cache[text_hash] = structured_data
151
- logger.info(f"LLM API processing for {filename}, took {time.time() - start_time:.2f} seconds, {log_memory_usage()}")
152
- return structured_data
153
- except Exception as e:
154
- logger.error(f"LLM API processing failed for {filename}: {str(e)}, {log_memory_usage()}")
155
- return {"error": f"LLM API processing failed: {str(e)}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  @app.post("/ocr")
158
  async def extract_and_structure(files: List[UploadFile] = File(...)):
@@ -265,10 +280,10 @@ async def extract_and_structure(files: List[UploadFile] = File(...)):
265
  raw_text_cache[file_hash] = raw_text
266
  logger.info(f"Text normalization for {file.filename}, took {time.time() - normalize_start_time:.2f} seconds, text length: {len(raw_text)}, {log_memory_usage()}")
267
  except Exception as e:
268
- logger.warning(f"Text normalization failed for {file.filename}: {str(e)}, {log_memory_usage()}")
269
 
270
  # Process with LLM API
271
- structured_data = await process_with_qwen(file.filename, raw_text)
272
  success_count += 1
273
  output_json["data"].append({
274
  "filename": file.filename,
 
33
  logger.error("HF_TOKEN environment variable not set")
34
  raise HTTPException(status_code=500, detail="HF_TOKEN environment variable not set")
35
 
36
+ # Initialize Hugging Face Inference Client
37
+ client = InferenceClient(token=hf_token)
38
+ logger.info("Hugging Face Inference Client initialized")
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  # In-memory caches (1-hour TTL)
41
  raw_text_cache = cachetools.TTLCache(maxsize=100, ttl=3600)
 
60
  start_time = time.time()
61
  logger.info(f"Starting OCR for {filename} image {idx}, {log_memory_usage()}")
62
  try:
63
+ img = Image.open(io.BytesIO(img_bytes)).resize((600, 400)) # Smaller for speed
64
  img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
65
  gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
66
  img_pil = Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
 
77
  start_time = time.time()
78
  logger.info(f"Starting OCR for PDF page {page_idx}, {log_memory_usage()}")
79
  try:
80
+ img = img.resize((600, 400)) # Smaller for speed
81
  img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
82
  gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
83
  img_pil = Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
 
89
  logger.error(f"OCR failed for PDF page {page_idx}: {str(e)}, {log_memory_usage()}")
90
  return ""
91
 
92
+ async def process_with_llm(filename: str, raw_text: str):
93
  """Process raw text with LLM via Hugging Face Inference API."""
94
  start_time = time.time()
95
  logger.info(f"Starting LLM API processing for {filename}, {log_memory_usage()}")
 
105
  raw_text = raw_text[:2000]
106
  logger.info(f"Truncated raw text for {filename} to 2000 characters, {log_memory_usage()}")
107
 
108
+ # Define models to try with retry logic
109
+ models = [
110
+ {"model": "google/gemma-2-9b-it", "provider": "auto"},
111
+ {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "provider": "auto"}
112
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
+ for model_info in models:
115
+ model = model_info["model"]
116
+ provider = model_info["provider"]
117
+ logger.info(f"Attempting LLM API call with model {model} and provider {provider}")
118
+ for attempt in range(2): # Retry once
119
+ try:
120
+ prompt = f"""
121
+ Extract key invoice fields as JSON from the raw text. Support English. Detect currency (e.g., USD, INR). Output only valid JSON, with no additional text, comments, or markdown.
122
+ Raw text: {raw_text}
123
+ Output JSON:
124
+ {{
125
+ "currency": "",
126
+ "Name_Client": "",
127
+ "Products": [],
128
+ "Subtotal": "",
129
+ "Tax": "",
130
+ "total": "",
131
+ "invoice date": "",
132
+ "invoice number": ""
133
+ }}
134
+ """
135
+ # Call Hugging Face Inference API
136
+ response = await asyncio.to_thread(client.chat_completion,
137
+ model=model,
138
+ messages=[{"role": "user", "content": prompt}],
139
+ max_tokens=256,
140
+ temperature=0.7,
141
+ provider=provider
142
+ )
143
+ llm_output = response.choices[0].message.content
144
+
145
+ # Extract JSON from output
146
+ llm_output = llm_output.strip()
147
+ if not llm_output.startswith("{"):
148
+ raise ValueError("API output is not valid JSON")
149
+ json_start = llm_output.find("{")
150
+ json_end = llm_output.rfind("}") + 1
151
+ json_str = llm_output[json_start:json_end]
152
+ try:
153
+ structured_data = json.loads(json_str)
154
+ except json.JSONDecodeError:
155
+ logger.warning(f"JSON parsing failed for {filename}, attempting to fix")
156
+ json_str = llm_output[llm_output.find("{"):llm_output.rfind("}")+1]
157
+ structured_data = json.loads(json_str)
158
+ structured_data_cache[text_hash] = structured_data
159
+ logger.info(f"LLM API processing for {filename} with {model}, attempt {attempt+1}, took {time.time() - start_time:.2f} seconds, {log_memory_usage()}")
160
+ return structured_data
161
+ except Exception as e:
162
+ logger.warning(f"LLM API processing failed for {filename} with {model}, attempt {attempt+1}: {str(e)}, {log_memory_usage()}")
163
+ if attempt == 1: # No more retries
164
+ break
165
+ await asyncio.sleep(1) # Wait before retry
166
+
167
+ # If all models fail
168
+ error_msg = "All LLM API models failed. Check model availability, authentication, or rate limits."
169
+ logger.error(f"{error_msg} for {filename}, {log_memory_usage()}")
170
+ return {"error": error_msg}
171
 
172
  @app.post("/ocr")
173
  async def extract_and_structure(files: List[UploadFile] = File(...)):
 
280
  raw_text_cache[file_hash] = raw_text
281
  logger.info(f"Text normalization for {file.filename}, took {time.time() - normalize_start_time:.2f} seconds, text length: {len(raw_text)}, {log_memory_usage()}")
282
  except Exception as e:
283
+ logger.warning(f"Text normalization failed for {filename}: {str(e)}, {log_memory_usage()}")
284
 
285
  # Process with LLM API
286
+ structured_data = await process_with_llm(file.filename, raw_text)
287
  success_count += 1
288
  output_json["data"].append({
289
  "filename": file.filename,