import json
import re
from datetime import datetime

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

# --------- QA model (Kaleidoscope) ----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
qa_model_id = "2KKLabs/Kaleidoscope_small_v1"
tokenizer = AutoTokenizer.from_pretrained(qa_model_id)
model = AutoModelForQuestionAnswering.from_pretrained(qa_model_id)
model.to(device)
model.eval()

# Allowed expense categories; also the choices offered by the "Tipo" dropdown.
TIPOS = [
    "coche",
    "comidas",
    "envio postal",
    "estacionamiento",
    "hoteles",
    "peaje",
    "taxis",
    "telefono/celular/internet",
    "tren",
    "vuelos",
]


# --------- OCR: image -> text (placeholder) ----------
def ocr_image_to_text(image):
    """Convert a receipt image to raw text.

    Placeholder: replace with a real OCR engine (easyocr, paddleocr, ...).
    For now it ignores ``image`` and returns a fixed stub so the rest of the
    pipeline can be exercised end to end.
    """
    return "stub text from OCR with date 2024-11-01 amount 23.50 EUR bar Velodromo comidas"


# --------- Post-processing helpers ----------

# Date formats recognised in OCR text, tried in order. Each regex captures
# three numeric groups; the index tuple says which group holds the
# (year, month, day) parts.
_DATE_PATTERNS = (
    (r"(\d{4})-(\d{2})-(\d{2})", (0, 1, 2)),  # 2024-11-01 (YYYY-MM-DD)
    (r"(\d{2})/(\d{2})/(\d{4})", (2, 1, 0)),  # 01/11/2024 (DD/MM/YYYY)
    (r"(\d{2})-(\d{2})-(\d{4})", (2, 1, 0)),  # 01-11-2024 (DD-MM-YYYY)
)


def normalize_date(text: str) -> str:
    """Return the first recognisable date in ``text`` as ``YYYY-MM-DD``.

    Patterns in ``_DATE_PATTERNS`` are tried in order; for each one only the
    first match is considered. A match that is not a real calendar date
    (e.g. month 13) is skipped and the next pattern is tried.

    Returns:
        The ISO-formatted date, or ``""`` when no valid date is found.
    """
    for pattern, (year_i, month_i, day_i) in _DATE_PATTERNS:
        match = re.search(pattern, text)
        if not match:
            continue
        groups = match.groups()
        try:
            parsed = datetime(int(groups[year_i]), int(groups[month_i]), int(groups[day_i]))
        except ValueError:
            # Out-of-range day/month: try the next format instead of failing.
            continue
        return parsed.strftime("%Y-%m-%d")
    return ""


def normalize_amount(text: str) -> str:
    """Return the first ``<digits>[.,]<2 digits>`` amount in ``text``.

    A decimal comma is converted to a dot (``"23,50"`` -> ``"23.50"``).
    Returns ``""`` when no such amount is present.
    """
    match = re.search(r"(\d+[.,]\d{2})", text)
    if not match:
        return ""
    return match.group(1).replace(",", ".")


# Keyword table for the heuristic category guess. Order matters: the first
# category with any keyword contained in the (lower-cased) text wins.
_TIPO_KEYWORDS = (
    ("estacionamiento", ("parking", "aparcamiento")),
    ("peaje", ("peaje", "toll")),
    ("taxis", ("taxi",)),
    ("hoteles", ("hotel",)),
    ("tren", ("train", "renfe", "tren")),
    ("vuelos", ("flight", "vueling", "iberia")),
    ("coche", ("diesel", "fuel", "gasolina")),
    ("telefono/celular/internet", ("internet", "movistar", "vodafone")),
)


def best_tipo_from_text(text: str) -> str:
    """Guess an expense category from keywords in ``text``.

    Matching is case-insensitive substring search. Falls back to
    ``"comidas"`` when no keyword matches.
    """
    lowered = text.lower()
    for tipo, keywords in _TIPO_KEYWORDS:
        if any(keyword in lowered for keyword in keywords):
            return tipo
    return "comidas"


def truncate_desc(desc: str, max_words: int = 6) -> str:
    """Keep at most ``max_words`` whitespace-separated words of ``desc``.

    A description already within the limit is returned unchanged; otherwise
    the kept words are re-joined with single spaces.
    """
    words = desc.split()
    if len(words) <= max_words:
        return desc
    return " ".join(words[:max_words])


def _as_text(value) -> str:
    """Coerce a value parsed from JSON to a stripped string (``None`` -> ``""``)."""
    return "" if value is None else str(value).strip()


# --------- QA model call ----------
def qa_answer(context, question, max_length=384):
    """Ask the extractive QA model ``question`` about ``context``.

    The answer is the token span between the arg-max start and arg-max end
    logits, decoded back to text.

    Args:
        context: Text the answer is extracted from.
        question: Natural-language question.
        max_length: Token budget for question + context; longer inputs are
            truncated by the tokenizer.

    Returns:
        The decoded answer span, stripped; ``""`` when the predicted end
        precedes the predicted start.
    """
    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # Batch size is 1: take the first (only) row of logits and token ids.
    start_index = int(torch.argmax(outputs.start_logits[0]))
    end_index = int(torch.argmax(outputs.end_logits[0]))
    if end_index < start_index:
        return ""

    # Slice along the token axis, not the batch axis.
    answer_tokens = inputs["input_ids"][0][start_index : end_index + 1]
    answer = tokenizer.decode(answer_tokens.tolist(), skip_special_tokens=True)
    return answer.strip()


# --------- Main pipeline ----------
def process_receipt(image):
    """Extract expense fields from a receipt image.

    Runs OCR, asks the QA model for a JSON object, and falls back to regex and
    keyword heuristics over the OCR text for anything missing or invalid.

    Returns:
        ``(fecha, tipo, descripcion, amount, comments)``. ``fecha`` is
        ``YYYY-MM-DD`` or ``""``. ``tipo`` is always one of ``TIPOS``.
        ``descripcion`` has at most 6 words. ``amount`` is a float (0.0 when
        unknown). ``comments`` is a string.
    """
    # 1) Image -> text
    context = ocr_image_to_text(image)

    # 2) Ask the model for a raw JSON answer
    question = (
        "From this receipt text extract: "
        "fecha (date), tipo (one of coche, comidas, envio postal, estacionamiento, hoteles, peaje, "
        "taxis, telefono/celular/internet, tren, vuelos), "
        "description (<=6 words), amount (numeric), comments (business name). "
        "Return only a JSON object with keys: fecha, tipo, description, amount, comments."
    )
    raw_answer = qa_answer(context, question)

    # 3) Parse, or fall back to heuristics.
    # qa_answer returns a decoded span of the input tokens, so a JSON object
    # only comes back if such text is present; the fallback is the common path.
    try:
        obj = json.loads(raw_answer)
    except ValueError:  # includes json.JSONDecodeError
        obj = None

    if isinstance(obj, dict):
        fecha = _as_text(obj.get("fecha"))
        tipo = _as_text(obj.get("tipo"))
        descripcion = _as_text(obj.get("description"))
        amount = _as_text(obj.get("amount"))
        comments = _as_text(obj.get("comments"))
    else:
        fecha = normalize_date(context)
        amount = normalize_amount(context)
        tipo = best_tipo_from_text(context)
        descripcion = "expense item"
        lines = context.splitlines()
        # First OCR line (capped at 60 chars) as a rough business-name guess.
        comments = lines[0][:60] if lines else ""

    # 4) Normalisation
    if not fecha:
        fecha = normalize_date(context)
    if tipo not in TIPOS:
        tipo = best_tipo_from_text(context)
    descripcion = truncate_desc(descripcion, 6)

    try:
        # Accept a decimal comma from the model as well as a dot.
        amount_val = float(amount.replace(",", "."))
    except ValueError:
        # Missing/unparseable amount: use the first amount in the OCR text.
        fallback_amount = normalize_amount(context)
        amount_val = float(fallback_amount) if fallback_amount else 0.0

    return fecha, tipo, descripcion, amount_val, comments


# --------- Gradio interface ----------
with gr.Blocks(title="Receiptesting - Kaleidoscope") as demo:
    gr.Markdown(
        "## Receiptesting con Kaleidoscope_small_v1\n\n"
        "Sube una imagen de un recibo y se extraerán: **fecha**, **tipo**, "
        "**descripción corta**, **amount** y **comentarios (nombre del negocio)**."
    )

    with gr.Row():
        with gr.Column():
            image_in = gr.Image(
                type="pil",
                label="Imagen del recibo",
            )
            btn = gr.Button("Extraer")

        with gr.Column():
            fecha_out = gr.Textbox(label="Fecha (YYYY-MM-DD)")
            tipo_out = gr.Dropdown(
                choices=TIPOS,
                label="Tipo",
            )
            desc_out = gr.Textbox(label="Descripción (<= 6 palabras)")
            amount_out = gr.Number(label="Amount")
            comments_out = gr.Textbox(label="Comentarios (nombre del negocio)")

    btn.click(
        process_receipt,
        inputs=[image_in],
        outputs=[fecha_out, tipo_out, desc_out, amount_out, comments_out],
    )

if __name__ == "__main__":
    demo.launch()