# Business-card parser Space: OCR (PaddleOCR) + entity extraction (GLiNER)
# + QR decoding (zxingcpp), served through a Gradio interface.
# Standard library
import io
import logging
import os
import re
import tempfile
import traceback

# Third-party
import cv2
import gradio as gr
import numpy as np
import pandas as pd
import zxingcpp  # QR decoding
from gliner import GLiNER
from paddleocr import PaddleOCR
from PIL import Image
| # Configure logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # Set up GLiNER environment variables | |
| os.environ['GLINER_HOME'] = './gliner_models' | |
| # Load GLiNER model | |
| try: | |
| logger.info("Loading GLiNER model...") | |
| gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1") | |
| except Exception as e: | |
| logger.error("Failed to load GLiNER model") | |
| raise e | |
| # Get a random color (used for drawing bounding boxes, if needed) | |
| def get_random_color(): | |
| return tuple(np.random.randint(0, 256, 3).tolist()) | |
| def scan_qr_code(image): | |
| """ | |
| Attempts to scan a QR code from the given PIL image using zxingcpp. | |
| The image is first saved to a temporary file to be read by zxingcpp. | |
| If the direct decoding fails, the function tries a fallback | |
| where the image is converted based on a default QR color (black) and tolerance. | |
| """ | |
| try: | |
| # Save the PIL image to a temporary file | |
| with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: | |
| image.save(tmp, format="PNG") | |
| tmp_path = tmp.name | |
| # Convert the saved image to a CV2 image | |
| img_cv = cv2.imread(tmp_path) | |
| # First attempt: direct decoding with zxingcpp | |
| try: | |
| results = zxingcpp.read_barcodes(img_cv) | |
| if results and results[0].text: | |
| return results[0].text.strip() | |
| except Exception as e: | |
| logger.warning(f"Direct zxingcpp decoding failed: {e}") | |
| # Fallback: Process image by converting specific QR colors with default parameters. | |
| default_color = "#000000" # Default QR color assumed (black) | |
| tolerance = 50 # Fixed tolerance value | |
| qr_img = image.convert("RGB") | |
| datas = list(qr_img.getdata()) | |
| newData = [] | |
| # Convert hex default color to an RGB tuple | |
| h1 = default_color.strip("#") | |
| rgb_tup = tuple(int(h1[i:i+2], 16) for i in (0, 2, 4)) | |
| for item in datas: | |
| # Check if the pixel is within the tolerance of the default color | |
| if (item[0] in range(rgb_tup[0]-tolerance, rgb_tup[0]+tolerance) and | |
| item[1] in range(rgb_tup[1]-tolerance, rgb_tup[1]+tolerance) and | |
| item[2] in range(rgb_tup[2]-tolerance, rgb_tup[2]+tolerance)): | |
| newData.append((0, 0, 0)) | |
| else: | |
| newData.append((255, 255, 255)) | |
| qr_img.putdata(newData) | |
| fallback_path = tmp_path + "_converted.png" | |
| qr_img.save(fallback_path) | |
| img_cv = cv2.imread(fallback_path) | |
| try: | |
| results = zxingcpp.read_barcodes(img_cv) | |
| if results and results[0].text: | |
| return results[0].text.strip() | |
| except Exception as e: | |
| logger.error(f"Fallback decoding failed: {e}") | |
| return None | |
| except Exception as e: | |
| logger.error(f"QR scan failed: {str(e)}") | |
| return None | |
| def extract_emails(text): | |
| email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b" | |
| return re.findall(email_regex, text) | |
| def extract_websites(text): | |
| website_regex = r"\b(?:https?://)?(?:www\.)?([A-Za-z0-9-]+\.[A-Za-z]{2,})(?:/\S*)?\b" | |
| matches = re.findall(website_regex, text) | |
| return [m for m in matches if '@' not in m] | |
| def clean_phone_number(phone): | |
| cleaned = re.sub(r"(?!^\+)[^\d]", "", phone) | |
| if len(cleaned) < 9 or (len(cleaned) == 9 and cleaned.startswith("+")): | |
| return None | |
| return cleaned | |
| def normalize_website(url): | |
| url = url.lower().replace("www.", "").split('/')[0] | |
| if not re.match(r"^[a-z0-9-]+\.[a-z]{2,}$", url): | |
| return None | |
| return f"www.{url}" | |
| def extract_address(ocr_texts): | |
| address_keywords = ["block", "street", "ave", "area", "industrial", "road"] | |
| address_parts = [] | |
| for text in ocr_texts: | |
| if any(kw in text.lower() for kw in address_keywords): | |
| address_parts.append(text) | |
| return " ".join(address_parts) if address_parts else None | |
| def inference(img: Image.Image, confidence): | |
| try: | |
| ocr = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False, | |
| det_model_dir='./models/det/en', | |
| cls_model_dir='./models/cls/en', | |
| rec_model_dir='./models/rec/en') | |
| img_np = np.array(img) | |
| result = ocr.ocr(img_np, cls=True)[0] | |
| ocr_texts = [line[1][0] for line in result] | |
| ocr_text = " ".join(ocr_texts) | |
| labels = ["person name", "company name", "job title", | |
| "phone number", "email address", "address", | |
| "website"] | |
| entities = gliner_model.predict_entities(ocr_text, labels, threshold=confidence, flat_ner=True) | |
| results = { | |
| "Person Name": [], | |
| "Company Name": [], | |
| "Job Title": [], | |
| "Phone Number": [], | |
| "Email Address": [], | |
| "Address": [], | |
| "Website": [], | |
| "QR Code": [] | |
| } | |
| # Process entities with validation | |
| for entity in entities: | |
| text = entity["text"].strip() | |
| label = entity["label"].lower() | |
| if label == "phone number": | |
| if (cleaned := clean_phone_number(text)): | |
| results["Phone Number"].append(cleaned) | |
| elif label == "email address" and "@" in text: | |
| results["Email Address"].append(text.lower()) | |
| elif label == "website": | |
| if (normalized := normalize_website(text)): | |
| results["Website"].append(normalized) | |
| elif label == "address": | |
| results["Address"].append(text) | |
| elif label == "company name": | |
| results["Company Name"].append(text) | |
| elif label == "person name": | |
| results["Person Name"].append(text) | |
| elif label == "job title": | |
| results["Job Title"].append(text.title()) | |
| # Regex fallbacks | |
| results["Email Address"] += extract_emails(ocr_text) | |
| results["Website"] += [normalize_website(w) for w in extract_websites(ocr_text)] | |
| # Phone number validation | |
| seen_phones = set() | |
| for phone in results["Phone Number"] + re.findall(r'\+\d{8,}|\d{9,}', ocr_text): | |
| if (cleaned := clean_phone_number(phone)) and cleaned not in seen_phones: | |
| results["Phone Number"].append(cleaned) | |
| seen_phones.add(cleaned) | |
| results["Phone Number"] = list(seen_phones) | |
| # Address processing | |
| if not results["Address"]: | |
| if (address := extract_address(ocr_texts)): | |
| results["Address"].append(address) | |
| # Website normalization | |
| seen_websites = set() | |
| final_websites = [] | |
| for web in results["Website"]: | |
| if web and web not in seen_websites: | |
| final_websites.append(web) | |
| seen_websites.add(web) | |
| results["Website"] = final_websites | |
| # Company name fallback | |
| if not results["Company Name"]: | |
| if results["Email Address"]: | |
| domain = results["Email Address"][0].split('@')[-1].split('.')[0] | |
| results["Company Name"].append(domain.title()) | |
| elif results["Website"]: | |
| domain = results["Website"][0].split('.')[1] | |
| results["Company Name"].append(domain.title()) | |
| # Name fallback | |
| if not results["Person Name"]: | |
| for text in ocr_texts: | |
| if re.match(r"^(?:[A-Z][a-z]+\s?){2,}$", text): | |
| results["Person Name"].append(text) | |
| break | |
| # QR Code scanning using the new zxingcpp-based function | |
| if (qr_data := scan_qr_code(img)): | |
| results["QR Code"].append(qr_data) | |
| # Create CSV file containing the results | |
| csv_data = {k: "; ".join(v) for k, v in results.items() if v} | |
| with tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w") as tmp_file: | |
| pd.DataFrame([csv_data]).to_csv(tmp_file, index=False) | |
| csv_path = tmp_file.name | |
| return ocr_text, csv_data, csv_path, "" | |
| except Exception as e: | |
| logger.error(f"Processing failed: {traceback.format_exc()}") | |
| return "", {}, None, f"Error: {str(e)}\n{traceback.format_exc()}" | |
| # Gradio Interface | |
| title = 'Enhanced Business Card Parser' | |
| description = 'Accurate entity extraction with combined AI and regex validation' | |
| if __name__ == '__main__': | |
| demo = gr.Interface( | |
| inference, | |
| [gr.Image(type='pil', label='Upload Business Card'), | |
| gr.Slider(0.1, 1, 0.4, step=0.1, label='Confidence Threshold')], | |
| [gr.Textbox(label="OCR Result"), | |
| gr.JSON(label="Structured Data"), | |
| gr.File(label="Download CSV"), | |
| gr.Textbox(label="Error Log")], | |
| title=title, | |
| description=description, | |
| css=".gr-interface {max-width: 800px !important;}" | |
| ) | |
| demo.launch() | |