Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from paddleocr import PaddleOCR | |
| from groq import Groq | |
| from openai import OpenAI | |
| import os | |
| import json | |
##################################
# Initialize Models
##################################
# Startup banner only: the OCR models themselves are built lazily by
# get_ocr_model() the first time each language is requested.
print("Loading PaddleOCR model...")

# UI display name -> PaddleOCR `lang` code. Dict insertion order is the
# order the languages appear in the language dropdown.
AVAILABLE_LANGUAGES = {
    "English": "en",
    "Chinese Simplified": "ch",
    "French": "fr",
    "German": "german",
    "Korean": "korean",
    "Japanese": "japan",
    "Italian": "it",
    "Spanish": "es",
    "Portuguese": "pt",
    "Russian": "ru",
    "Arabic": "ar",
    "Hindi": "hi",
    "Vietnamese": "vi",
    "Thai": "th",
}

# LLM back-ends offered for JSON structuring; "None" means raw OCR output only.
PROVIDERS = ["None", "Groq", "OpenAI"]

# Cache of constructed OCR engines, keyed by PaddleOCR lang code.
ocr_models = {}


def get_ocr_model(lang_code):
    """Return the PaddleOCR engine for *lang_code*, building it on first use.

    Engines are kept in the module-level ``ocr_models`` dict so each language
    is initialised at most once per process.

    Args:
        lang_code: A PaddleOCR language code (a value of AVAILABLE_LANGUAGES).

    Returns:
        The cached (or newly created) PaddleOCR instance.
    """
    if lang_code in ocr_models:
        return ocr_models[lang_code]

    engine = PaddleOCR(
        use_angle_cls=True,
        lang=lang_code,
        show_log=False,
        enable_mkldnn=True,  # MKL-DNN acceleration for CPU inference
    )
    ocr_models[lang_code] = engine
    return engine
##################################
# Groq Processing Functions
##################################
def format_with_groq(text: str, api_key: str, model: str = "llama3-8b-8192") -> str:
    """First Groq pass: have the LLM pick out the key fields of a receipt.

    The reply is only loosely structured; a second pass
    (``refine_json_with_groq``) forces it into the app's exact JSON schema.

    Args:
        text: Raw OCR text of the receipt.
        api_key: Groq API key used to construct the client.
        model: Groq chat model ID. Defaults to the ID this app was written
            against; Groq retires model IDs over time, so pass a current one
            if the default is rejected.

    Returns:
        The streamed completion concatenated into one string, with
        surrounding whitespace stripped.
    """
    client = Groq(api_key=api_key)
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": (
                    "You are a receipt data extraction expert. Extract and format the receipt data into a clear JSON structure.\n"
                    "Look for these key pieces of information:\n"
                    "1. Restaurant/store name\n"
                    "2. Restaurant Address /store address\n"
                    "3. Date and time\n"
                    "4. Individual items with quantities and prices\n"
                    "5. Table number if present\n"
                    "6. Server name if present\n"
                    "7. Payment details\n"
                    "8. Receipt/order number\n"
                    "Format numbers as actual numbers, not strings."
                ),
            },
            {
                "role": "user",
                "content": f"Convert this receipt text to structured data:\n\n{text}",
            },
        ],
        temperature=0.1,
        max_tokens=1024,
        top_p=1,
        stream=True,
    )

    parts = []
    for chunk in completion:
        # Skip chunks that carry no choices instead of raising IndexError.
        if not chunk.choices:
            continue
        content = getattr(chunk.choices[0].delta, "content", None)
        if content:
            parts.append(content)
    return "".join(parts).strip()
def refine_json_with_groq(initial_text: str, api_key: str, model: str = "llama3-8b-8192") -> str:
    """Second Groq pass: coerce loosely structured receipt data into strict JSON.

    Args:
        initial_text: Output of the first pass (``format_with_groq``).
        api_key: Groq API key used to construct the client.
        model: Groq chat model ID. Defaults to the ID this app was written
            against; pass a current one if Groq has retired the default.

    Returns:
        A pretty-printed (indent=2) JSON string when the model's reply
        parses. Otherwise the best-effort candidate text: the outermost
        ``{...}`` span if one exists, else the whole reply. The caller is
        expected to run ``json.loads`` and handle failure.
    """
    client = Groq(api_key=api_key)
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": (
                    "Convert the receipt data into this exact JSON format:\n"
                    "{\n"
                    "  'restaurant_name': string,\n"
                    "  'restaurant_address': string,\n"
                    "  'date': string,\n"
                    "  'time': string,\n"
                    "  'table_number': string or number,\n"
                    "  'server_name': string,\n"
                    "  'payment_method': string,\n"
                    "  'items': [{'name': string, 'quantity': number, 'price': number}],\n"
                    "  'subtotal': number,\n"
                    "  'tax': number,\n"
                    "  'tip': number or null,\n"
                    "  'total': number,\n"
                    "  'receipt_number': string or null\n"
                    "}\n"
                    "Rules:\n"
                    "1. Use ONLY double quotes for JSON compliance\n"
                    "2. All numbers must be actual numbers, not strings\n"
                    "3. Return ONLY the JSON, no explanations\n"
                    "4. Ensure math is correct"
                ),
            },
            {
                "role": "user",
                "content": f"Format this receipt data as valid JSON:\n\n{initial_text}",
            },
        ],
        temperature=0.1,
        max_tokens=1024,
        top_p=1,
        stream=True,
    )

    parts = []
    for chunk in completion:
        # Skip chunks that carry no choices instead of raising IndexError.
        if not chunk.choices:
            continue
        content = getattr(chunk.choices[0].delta, "content", None)
        if content:
            parts.append(content)
    reply = "".join(parts)

    # The model may wrap the JSON in prose or markdown fences; keep only the
    # outermost {...} span, but only when a '}' actually follows the first
    # '{' (otherwise the slice would be empty and the reply would be lost).
    start = reply.find("{")
    end = reply.rfind("}")
    candidate = reply[start:end + 1] if 0 <= start < end else reply

    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        return candidate
    # Re-serialise so callers always get consistently formatted JSON.
    return json.dumps(parsed, indent=2)
##################################
# OpenAI Processing Functions
##################################
def process_with_openai(text: str, api_key: str, model: str = "gpt-3.5-turbo") -> str:
    """Ask OpenAI to convert raw receipt OCR text into the app's JSON schema.

    Args:
        text: Raw OCR text of the receipt.
        api_key: OpenAI API key used to construct the client.
        model: OpenAI chat model ID (defaults to the original hard-coded one).

    Returns:
        A string, never a dict. It is one of:
        - the outermost ``{...}`` span of the model's reply, when present
          (surrounding prose or markdown fences are dropped);
        - otherwise the reply as-is, with ``None`` content mapped to "";
        - a JSON-encoded ``{"error": ...}`` object if creating the client or
          calling the API raises.
    """
    try:
        client = OpenAI(api_key=api_key)
        completion = client.chat.completions.create(
            model=model,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "Convert the receipt data into this exact JSON format:\n"
                        "{\n"
                        "  'restaurant_name': string,\n"
                        "  'restaurant_address': string,\n"
                        "  'date': string,\n"
                        "  'time': string,\n"
                        "  'table_number': string or number,\n"
                        "  'server_name': string,\n"
                        "  'payment_method': string,\n"
                        "  'items': [{'name': string, 'quantity': number, 'price': number}],\n"
                        "  'subtotal': number,\n"
                        "  'tax': number,\n"
                        "  'tip': number or null,\n"
                        "  'total': number,\n"
                        "  'receipt_number': string or null\n"
                        "}\n"
                        "Rules:\n"
                        "1. Use ONLY double quotes for JSON compliance\n"
                        "2. All numbers must be actual numbers, not strings\n"
                        "3. Return ONLY the JSON, no explanations"
                    ),
                },
                {
                    "role": "user",
                    "content": f"Convert this receipt text to JSON:\n\n{text}",
                },
            ],
            temperature=0.1,
        )
        # message.content may be None; normalise so slicing below is safe.
        reply = completion.choices[0].message.content or ""
    except Exception as e:
        # API boundary: report client/network/auth failures as JSON so the
        # caller's json.loads still yields a displayable error object.
        return json.dumps({"error": str(e)})

    # Chat models often wrap JSON in ```json fences or add a sentence of
    # prose; trim to the outermost {...} span, mirroring the Groq path.
    start = reply.find("{")
    end = reply.rfind("}")
    if 0 <= start < end:
        return reply[start:end + 1]
    return reply
##################################
# Main Processing
##################################
def process_receipt(image, selected_language, provider="None", api_key=""):
    """Run OCR on a receipt image and optionally structure the text with an LLM.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submitted without uploading anything.
        selected_language: Display name; a key of ``AVAILABLE_LANGUAGES``.
        provider: One of ``PROVIDERS`` ("None", "Groq", "OpenAI").
        api_key: API key for the selected provider; surrounding whitespace
            is ignored.

    Returns:
        A dict rendered by ``gr.JSON``. It is one of:
        - the parsed LLM output on success;
        - ``{"raw_ocr_text", "note"}`` when no provider or key is given;
        - a dict with an ``"error"`` entry describing what failed.
    """
    if image is None:
        return {
            "error": "No image provided. Please upload a receipt image.",
            "type": "input_error",
        }

    # Bound before the try so the finally clause can always reference it.
    image_path = os.path.join("temp", "temp_image.jpg")
    try:
        os.makedirs("temp", exist_ok=True)
        # JPEG cannot store alpha or palette modes (e.g. RGBA PNG uploads),
        # so normalise to RGB before writing the temp file.
        image.convert("RGB").save(image_path)

        # Get OCR model and process image
        lang_code = AVAILABLE_LANGUAGES[selected_language]
        ocr_model = get_ocr_model(lang_code)
        result = ocr_model.ocr(image_path, cls=True)

        # One entry per page. PaddleOCR may return None for a page with no
        # detected text, so skip such pages instead of iterating over None.
        extracted_text = "\n".join(
            line[1][0] for page in (result or []) if page for line in page
        )

        # If no provider/api key, return raw OCR
        api_key = (api_key or "").strip()
        if not api_key or provider == "None":
            return {
                "raw_ocr_text": extracted_text,
                "note": "Provide API key and select a provider for structured JSON output",
            }

        try:
            if provider == "Groq":
                # Two-step Groq processing: extract fields, then enforce schema.
                initial_text = format_with_groq(extracted_text, api_key)
                final_json = refine_json_with_groq(initial_text, api_key)
                return json.loads(final_json)
            if provider == "OpenAI":
                return json.loads(process_with_openai(extracted_text, api_key))
        except json.JSONDecodeError:
            return {
                "error": "Failed to parse response",
                "raw_ocr_text": extracted_text,
            }

        # Previously fell through and returned None (an empty JSON panel).
        return {
            "error": f"Unknown provider: {provider}",
            "raw_ocr_text": extracted_text,
        }
    except Exception as e:
        # UI boundary: surface any failure in the output panel, not a traceback.
        return {
            "error": str(e),
            "type": "processing_error",
        }
    finally:
        if os.path.exists(image_path):
            try:
                os.remove(image_path)
            except OSError:
                pass  # best-effort cleanup; a leftover temp file is harmless
##################################
# Gradio Interface
##################################
# Extra CSS injected into the page: widens the main Gradio container.
css = """
.gradio-container {max-width: 1100px !important}
"""
# NOTE(review): the nesting below was reconstructed from a copy with
# flattened indentation. Confirm which container the inner Row, the Button
# and the usage Markdown belong to.
# Component creation order inside each `with` context defines the layout.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Multi-Language Receipt OCR")
    with gr.Row():
        # Left column: receipt image, OCR language, and optional LLM settings.
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="Upload Receipt Image",
                height=400
            )
            # Choices are display names; process_receipt maps the chosen name
            # to a PaddleOCR lang code via AVAILABLE_LANGUAGES.
            language_dropdown = gr.Dropdown(
                choices=list(AVAILABLE_LANGUAGES.keys()),
                value="English",
                label="Select Language",
                info="Choose the primary language of the receipt"
            )
            # Provider and API key side by side; with provider "None" or an
            # empty key, process_receipt returns raw OCR text only.
            with gr.Row():
                provider_dropdown = gr.Dropdown(
                    choices=PROVIDERS,
                    value="None",
                    label="Select LLM Provider",
                    info="Choose provider for JSON formatting"
                )
                api_key_input = gr.Textbox(
                    label="API Key",
                    placeholder="Enter your API key",
                    type="password",
                    info="Required for JSON formatting"
                )
            submit_button = gr.Button("Process Receipt", variant="primary")
        # Right column: displays the dict returned by process_receipt.
        with gr.Column(scale=1):
            json_output = gr.JSON(
                label="Extracted Receipt Data",
                height=500
            )
    gr.Markdown("""
### Usage Instructions
1. Upload a clear image of your receipt
2. Select the receipt's primary language
3. (Optional) Choose a provider and enter API key for JSON formatting
4. Click 'Process Receipt'
### Notes
- Without an API key, you'll receive raw OCR text
- For best results, ensure receipt image is clear and well-lit
- Supported languages include English, Chinese, French, German, and more
""")
    # The inputs list is passed positionally and must match process_receipt's
    # parameter order: (image, selected_language, provider, api_key).
    submit_button.click(
        fn=process_receipt,
        inputs=[
            image_input,
            language_dropdown,
            provider_dropdown,
            api_key_input
        ],
        outputs=[json_output],
    )
# Close any existing gradio instances
gr.close_all()
# Launch the app
# Enable request queuing, holding at most 10 pending events.
demo.queue(max_size=10)
# Bind on all network interfaces, port 7860; no public share link.
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    show_api=False,
    share=False
)