# ReceiptSplitAI / qr_retriever.py
# valentynliubchenko
# merging
# eba303d
import io
from datetime import datetime
from typing import Tuple, List, Dict, Optional
import numpy as np
import requests
from PIL import Image
from loguru import logger as log
from qreader import QReader
# Country name written to the "country" field of every converted receipt.
SLOVAKIA = "Slovakia"
# Currency code written to the "currency" field of every converted receipt.
EUR = "EUR"
# Shared QReader instance (model size 'm'), created once at import time and
# used by get_receipt_by_qr to detect and decode QR codes.
qreader = QReader(model_size='m')
# =============================== Main Methods ===============================
def get_receipt_by_qr(image_path: str) -> Tuple[Optional[Dict], Optional[Dict]]:
    """
    Return converted receipt recognized from QR code

    Args:
        image_path (str): path to input image
    Returns:
        Tuple[Dict, Dict]: Tuple of original JSON from API and receipt in appropriate format,
        Tuple[None, None]: In case if QR code is not recognized
    """
    image = get_png_image(image_path)
    decoded_text = qreader.detect_and_decode(image=image)
    # The decoder result can be empty (nothing detected) or start with a
    # None/empty entry (detected but not decoded). Indexing an empty result
    # would raise IndexError, so both cases are treated as "no usable QR".
    if not decoded_text or not decoded_text[0]:
        log.error("QR code not found")
        return None, None
    return get_receipt_by_uid(decoded_text[0])
def get_receipt_by_uid(receipt_uid: str) -> Tuple[Optional[Dict], Optional[Dict]]:
    """
    Return receipt gotten from API and converted

    Args:
        receipt_uid (str): uid of the receipt
    Returns:
        Tuple[Dict, Dict]: Tuple of original JSON from API and receipt in appropriate format,
        Tuple[None, None]: In case the receipt could not be fetched from the API
    """
    original_json = get_receipt_from_api(receipt_uid)
    # get_receipt_from_api returns None when the request fails; converting
    # None would crash, so report the failure the same way a missing QR code
    # is reported.
    if original_json is None:
        log.error("Receipt {} could not be retrieved from API", receipt_uid)
        return None, None
    receipt = convert_receipt(original_json)
    return original_json, receipt
# =============================== Utils ===============================
def get_receipt_from_api(receipt_uid: str, timeout: float = 30.0) -> Optional[Dict]:
    """
    Sends request to Slovak financial administration (eKasa) API

    Args:
        receipt_uid (str): uid of the receipt
        timeout (float): seconds to wait for the server before giving up
    Returns:
        Dict: JSON from API,
        None: In case of errors
    """
    url = "https://ekasa.financnasprava.sk/mdu/api/v1/opd/receipt/find"
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    data = {
        "receiptId": receipt_uid,
    }
    try:
        # Without an explicit timeout requests waits indefinitely on a stalled
        # connection.
        response = requests.post(url, json=data, headers=headers, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.HTTPError as http_err:
        log.error("HTTP error occurred: {}", http_err)
    except Exception as err:
        # Best-effort lookup: any other failure (network, invalid JSON, ...)
        # is logged and reported to the caller as None.
        log.error("Other error occurred: {}", err)
    return None
def convert_receipt(response: Dict) -> Optional[Dict]:
    """
    Converts API response to appropriate receipt format

    Args:
        response (Dict): response from API
    Returns:
        Dict: receipt in appropriate format,
        None: In case the response does not contain a receipt
    """
    rec_receipt = response.get("receipt") if response else None
    if not rec_receipt:
        log.error("API response does not contain a receipt")
        return None

    items = convert_items(rec_receipt.get("items") or [])
    taxes = get_taxes(rec_receipt)
    subtotal = calculate_subtotal(items)
    total_price = rec_receipt.get("totalPrice")

    # Nested sections may be absent; fall back to None fields instead of
    # failing on attribute access.
    organization = rec_receipt.get("organization") or {}
    unit = rec_receipt.get("unit")
    create_date = rec_receipt.get("createDate")

    if total_price is None:
        total_discount = None
    else:
        # Round away float subtraction noise; `or 0.0` turns a rounded -0.0
        # into plain 0.0.
        total_discount = round(total_price - subtotal, 2) or 0.0

    receipt = {
        "store_name": organization.get("name"),
        "country": SLOVAKIA,
        "receipt_type": None,
        "store_address": get_address(unit) if unit else None,
        "date_time": change_date_format(create_date) if create_date else None,
        "currency": EUR,
        "sub_total_amount": subtotal,
        "total_amount": total_price,
        "total_discount": total_discount,
        "all_items_price_with_tax": True,
        "payment_method": None,
        "rounding": None,
        "tax": calculate_tax(rec_receipt),
        "taxes_not_included_sum": 0.0,  # Field only for American\Canadian receipts
        "tips": None,
        "items": items,
        "taxs_items": taxes,
    }
    return receipt
def convert_items(recognized_items: Optional[List[Dict]]) -> List[Dict]:
    """
    Converts items from API to appropriate format

    Args:
        recognized_items (List[Dict]): Items gotten from API; None is treated
            as an empty list
    Returns:
        List[Dict]: Array with items in appropriate format. "unit_price" is
            None when the quantity is missing or zero.
    """
    items = []
    for rec_item in recognized_items or []:
        # `or 0` also covers keys that are present but hold None.
        price = round(float(rec_item.get("price") or 0), 2)
        quantity = round(float(rec_item.get("quantity") or 0), 2)
        # A missing or zero quantity makes the unit price undefined; report
        # it as None instead of raising ZeroDivisionError.
        unit_price = round(price / quantity, 2) if quantity else None
        items.append({
            "name": rec_item.get("name"),
            "unit_price": unit_price,
            "quantity": quantity,
            "unit_of_measurement": None,
            "total_price": price,
            "discount": None,
            "category": None,
            "item_price_with_tax": True,
        })
    return items
def get_taxes(recognized_receipt: Dict) -> List[Dict]:
    """
    Returns list of taxes (Basic and Reduced. Both or single by existence)

    Args:
        recognized_receipt (Dict): Dictionary that represents receipt gotten from API
    Returns:
        List[Dict]: Array with Basic or/and Reduced taxes
    """
    taxes = []
    for tax_type in ("Basic", "Reduced"):
        vat_rate = recognized_receipt.get(f"vatRate{tax_type}")
        # A missing or zero rate means this tax type is not on the receipt.
        if not vat_rate:
            continue
        tax_base = recognized_receipt.get(f"taxBase{tax_type}")
        vat_amount = recognized_receipt.get(f"vatAmount{tax_type}")
        tax_base = round(tax_base, 2) if tax_base is not None else 0.0
        vat_amount = round(vat_amount, 2) if vat_amount is not None else 0.0
        taxes.append({
            # Plain "<rate>%": the original f-string emitted a stray "$".
            "tax_name": f"{vat_rate}%",
            "percentage": float(vat_rate),
            "tax_from_amount": tax_base,
            "tax": vat_amount,
            "total": round(tax_base + vat_amount, 2),
            "tax_included": True,
        })
    return taxes
def calculate_tax(recognized_receipt: Dict) -> float:
    """
    Returns sum of Basic and Reduced tax amount

    Args:
        recognized_receipt (Dict): Dictionary that represents receipt gotten from API
    Returns:
        float: Summary tax, rounded to 2 decimal places
    """
    tax = 0.0
    for tax_type in ("Basic", "Reduced"):
        vat_amount = recognized_receipt.get(f"vatAmount{tax_type}")
        if vat_amount:
            tax += vat_amount
    # Round so float addition noise (e.g. 0.1 + 0.2) does not leak into the
    # monetary total; get_taxes rounds the same amounts to 2 decimals.
    return round(tax, 2)
def get_address(unit: Dict) -> str:
    """
    Builds the full store address from the cash register data

    Args:
        unit (Dict): Dictionary that represents cash register
    Returns:
        str: Address in format '{city}, {street_name} {property_number}, {postal_code}'.
            A key that is absent from `unit` is rendered as the text 'None'.
    """
    city = unit.get("municipality")
    street = unit.get("streetName")
    house_number = unit.get("propertyRegistrationNumber")
    zip_code = unit.get("postalCode")
    return "{}, {} {}, {}".format(city, street, house_number, zip_code)
def calculate_subtotal(items: List[Dict]) -> float:
    """
    Calculate the subtotal of items' total prices.

    Args:
        items (List[Dict]): List of items, where each item is a dictionary with a 'total_price' key.
            An item whose 'total_price' is missing or None contributes 0.
    Returns:
        float: The subtotal of all item total prices, rounded to 2 decimal places.
    """
    subtotal = 0.0
    for item in items:
        subtotal += float(item.get("total_price") or 0)
    # Round so float accumulation noise (e.g. 0.1 + 0.2) does not leak into
    # the monetary result.
    return round(subtotal, 2)
def get_png_image(img_path: str) -> np.ndarray:
    """
    Load an image, re-encode it as PNG in memory and return it as an array

    Args:
        img_path (str): path to input image
    Returns:
        np.ndarray: Pixel data of the PNG-encoded image
    """
    # Context managers release the file handle and the in-memory buffer even
    # if encoding or decoding fails.
    with Image.open(img_path) as image, io.BytesIO() as png_buffer:
        image.save(png_buffer, format='PNG')
        png_buffer.seek(0)
        with Image.open(png_buffer) as png_image:
            # Build the array before the buffer is closed.
            return np.array(png_image)
def change_date_format(date_str: str) -> str:
    """
    Re-renders a 'DD.MM.YYYY HH:MM:SS' timestamp as 'YYYY.MM.DD HH:MM:SS'

    Args:
        date_str (str): timestamp in day-first format
    Returns:
        str: the same timestamp in year-first format
    """
    source_layout = '%d.%m.%Y %H:%M:%S'
    target_layout = '%Y.%m.%d %H:%M:%S'
    return datetime.strptime(date_str, source_layout).strftime(target_layout)