# Hugging Face Spaces page header captured with the source (status: Sleeping)
| import json | |
| import re | |
| from datetime import datetime | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForQuestionAnswering | |
# --------- QA model (Kaleidoscope) ----------
# Loaded once at import time; every request reuses this tokenizer/model pair.
qa_model_id = "2KKLabs/Kaleidoscope_small_v1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(qa_model_id)
# nn.Module.to() returns the module itself, so the placement can be chained.
model = AutoModelForQuestionAnswering.from_pretrained(qa_model_id).to(device)
model.eval()  # evaluation mode: the model is only used for inference below
# Closed set of expense categories: offered as the UI dropdown choices and
# used to validate the "tipo" value produced by the model.
TIPOS = [
    "coche", "comidas", "envio postal", "estacionamiento", "hoteles",
    "peaje", "taxis", "telefono/celular/internet", "tren", "vuelos",
]
# --------- OCR: image -> text (placeholder) ----------
def ocr_image_to_text(image):
    """Turn a receipt image into raw text (placeholder).

    Swap this out for a real OCR engine (easyocr, paddleocr, ...). Until then
    *image* is ignored and a fixed sample string is returned so the rest of
    the pipeline can be exercised end to end.
    """
    sample_text = "stub text from OCR with date 2024-11-01 amount 23.50 EUR bar Velodromo comidas"
    return sample_text
# --------- Post-processing utilities ----------
def normalize_date(text):
    """Return the first recognizable date in *text* formatted as ``YYYY-MM-DD``.

    Supported layouts, tried in this order (only the first match of each
    layout is considered; if it is not a real calendar date the next layout
    is tried):

    * ``YYYY-MM-DD`` (e.g. ``2024-11-01``)
    * ``DD/MM/YYYY`` (e.g. ``01/11/2024``)
    * ``DD-MM-YYYY`` (e.g. ``01-11-2024``)

    Returns ``""`` when *text* is empty or no layout yields a valid date.
    """
    if not text:
        return ""
    # Each entry: (regex, True when the year is the FIRST captured group).
    # The group count is 3 for every pattern, so the layout has to be tagged
    # explicitly rather than inferred from len(groups).
    patterns = (
        (r"(\d{4})-(\d{2})-(\d{2})", True),   # 2024-11-01
        (r"(\d{2})/(\d{2})/(\d{4})", False),  # 01/11/2024
        (r"(\d{2})-(\d{2})-(\d{4})", False),  # 01-11-2024
    )
    for pattern, year_first in patterns:
        match = re.search(pattern, text)
        if not match:
            continue
        first, month, last = (int(part) for part in match.groups())
        year, day = (first, last) if year_first else (last, first)
        try:
            return datetime(year, month, day).strftime("%Y-%m-%d")
        except ValueError:
            # Digits matched but do not form a real date (e.g. 2024-13-45).
            continue
    return ""
def normalize_amount(text):
    """Extract the first ``digits[.,]dd`` amount from *text* as a string.

    A comma decimal separator is converted to a dot (``"7,95"`` ->
    ``"7.95"``). Returns ``""`` when no such number is present.
    """
    found = re.search(r"(\d+[.,]\d{2})", text)
    return found.group(1).replace(",", ".") if found else ""
def best_tipo_from_text(text):
    """Guess the expense category (one of ``TIPOS``) from receipt text.

    Keyword groups are checked in a fixed priority order against the
    lower-cased text (plain substring match); the first group with a hit
    decides the category. ``"comidas"`` is the default when nothing matches.
    """
    lowered = text.lower()
    rules = (
        (("parking", "aparcamiento"), "estacionamiento"),
        (("peaje", "toll"), "peaje"),
        (("taxi",), "taxis"),
        (("hotel",), "hoteles"),
        (("train", "renfe", "tren"), "tren"),
        (("flight", "vueling", "iberia"), "vuelos"),
        (("diesel", "fuel", "gasolina"), "coche"),
        (("internet", "movistar", "vodafone"), "telefono/celular/internet"),
    )
    return next(
        (tipo for keywords, tipo in rules if any(kw in lowered for kw in keywords)),
        "comidas",
    )
def truncate_desc(desc, max_words=6):
    """Limit *desc* to at most *max_words* whitespace-separated words.

    Text that is already short enough is returned untouched (original spacing
    preserved); longer text is rebuilt from its first *max_words* words joined
    by single spaces.
    """
    tokens = desc.split()
    if len(tokens) > max_words:
        return " ".join(tokens[:max_words])
    return desc
# --------- QA model call ----------
def qa_answer(context, question, max_length=384):
    """Return the answer span the extractive QA model selects from *context*.

    *question* and *context* are encoded as one pair (truncated to
    *max_length* tokens). The tokens with the highest start and end logits
    delimit the span, which is decoded back to text without special tokens
    and stripped of surrounding whitespace.

    Returns ``""`` when the predicted end token precedes the start token.
    """
    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    # A single (question, context) pair is encoded, so the batch size is 1:
    # take row 0 of the logits and of input_ids.
    start_index = int(torch.argmax(outputs.start_logits[0]))
    end_index = int(torch.argmax(outputs.end_logits[0]))
    if end_index < start_index:
        return ""
    # input_ids is (1, seq_len); slicing it directly would cut the batch
    # dimension instead of the token positions, yielding an empty answer
    # whenever start_index > 0.
    answer_tokens = inputs["input_ids"][0, start_index : end_index + 1]
    answer = tokenizer.decode(answer_tokens.tolist(), skip_special_tokens=True)
    return answer.strip()
# --------- Main pipeline ----------
def process_receipt(image):
    """Extract expense fields from a receipt image.

    Steps: OCR the image to text, ask the QA model for a JSON object holding
    the fields, and fall back to the regex/keyword helpers over the OCR text
    for anything the model does not provide in a usable form.

    Returns ``(fecha, tipo, descripcion, amount, comments)`` in the order of
    the Gradio outputs: ``fecha`` as ``YYYY-MM-DD`` or ``""``, ``tipo`` one of
    ``TIPOS``, ``descripcion`` with at most 6 words, ``amount`` as a float
    (0.0 when no amount can be found) and ``comments`` as a string.
    """
    # 1) Image -> text (guard against an OCR backend returning None).
    context = ocr_image_to_text(image) or ""

    # 2) Ask the model for a raw JSON answer.
    question = (
        "From this receipt text extract: "
        "fecha (date), tipo (one of coche, comidas, envio postal, estacionamiento, hoteles, peaje, "
        "taxis, telefono/celular/internet, tren, vuelos), "
        "description (<=6 words), amount (numeric), comments (business name). "
        "Return only a JSON object with keys: fecha, tipo, description, amount, comments."
    )
    raw_answer = qa_answer(context, question)

    # 3) Parse the model answer, or fall back to heuristics over the OCR text.
    try:
        obj = json.loads(raw_answer)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        obj = None

    if isinstance(obj, dict):
        # JSON null / non-string values are coerced to text so the
        # normalization steps below never receive None.
        fecha = str(obj.get("fecha") or "")
        tipo = str(obj.get("tipo") or "")
        descripcion = str(obj.get("description") or "")
        amount = str(obj.get("amount", ""))
        comments = str(obj.get("comments") or "")
    else:
        fecha = normalize_date(context)
        amount = normalize_amount(context)
        tipo = best_tipo_from_text(context)
        descripcion = "expense item"
        # The first OCR line stands in for the business name (the UI labels
        # this field "nombre del negocio"), capped at 60 characters.
        lines = context.splitlines()
        comments = lines[0][:60] if lines else ""

    # 4) Normalization
    if not fecha:
        fecha = normalize_date(context)
    if tipo not in TIPOS:
        tipo = best_tipo_from_text(context)
    descripcion = truncate_desc(descripcion, 6)
    try:
        amount_val = float(amount.replace(",", "."))
    except ValueError:
        # Like fecha/tipo above, fall back to the OCR text when the amount
        # is missing or not numeric.
        fallback = normalize_amount(context)
        amount_val = float(fallback) if fallback else 0.0
    return fecha, tipo, descripcion, amount_val, comments
# --------- Gradio interface ----------
# Left column: receipt upload + trigger button; right column: one output
# widget per field returned by process_receipt, in the same order.
with gr.Blocks(title="Receiptesting - Kaleidoscope") as demo:
    gr.Markdown(
        "## Receiptesting con Kaleidoscope_small_v1\n\n"
        "Sube una imagen de un recibo y se extraerán: **fecha**, **tipo**, "
        "**descripción corta**, **amount** y **comentarios (nombre del negocio)**."
    )
    with gr.Row():
        with gr.Column():
            image_in = gr.Image(
                type="pil",
                label="Imagen del recibo",
            )
            btn = gr.Button("Extraer")
        with gr.Column():
            fecha_out = gr.Textbox(label="Fecha (YYYY-MM-DD)")
            tipo_out = gr.Dropdown(
                choices=TIPOS,
                label="Tipo",
            )
            desc_out = gr.Textbox(label="Descripción (<= 6 palabras)")
            amount_out = gr.Number(label="Amount")
            comments_out = gr.Textbox(label="Comentarios (nombre del negocio)")
    # Event listeners must be registered inside the Blocks context.
    btn.click(
        process_receipt,
        inputs=[image_in],
        outputs=[fecha_out, tipo_out, desc_out, amount_out, comments_out],
    )

if __name__ == "__main__":
    demo.launch()