# Hugging Face Space page residue (non-code): Rhulli — "Update app.py", commit c578e92 (verified).
import os
import re
import unicodedata
import io
import torch
import gradio as gr
import pdfplumber
import pandas as pd
from transformers import (
AutoTokenizer,
AutoModelForTokenClassification,
AutoModelForCausalLM,
BitsAndBytesConfig,
)
from peft import PeftModel
# =========================
# Utilidades de normalización
# =========================
_SPACE_VARIANTS = r"[\u202f\u00a0\u2009\u200a\u2060]"
def _normalise_apostrophes(text: str) -> str:
return text.replace("´", "'").replace("’", "'")
def _normalise_spaces(text: str, collapse: bool = True) -> str:
text = re.sub(_SPACE_VARIANTS, " ", text)
text = unicodedata.normalize("NFKC", text)
if collapse:
text = re.sub(r"[ ]{2,}", " ", text)
return text.strip()
def _clean_timex(ent: str) -> str:
ent = ent.replace("</s>", "").strip()
return re.sub(r"[\.]+$", "", ent)
# =========================
# Identificadores de modelos
# =========================
# Fine-tuned RoBERTa token classifier that tags temporal expressions (TIMEX spans).
NER_ID = "Rhulli/Roberta-ner-temporal-expresions-secondtrain"
# BIO label map used to decode the classifier's argmax predictions.
ID2LABEL = {0: "O", 1: "B-TIMEX", 2: "I-TIMEX"}
# Base causal LM and the LoRA adapter fine-tuned for TIMEX3 normalization.
BASE_ID = "google/gemma-2b-it"
ADAPTER_ID = "Rhulli/gemma-2b-it-TIMEX3"
# =========================
# 4-bit quantization (NF4)
# =========================
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
# =========================
# HF token (if the repos are private)
# =========================
HF_TOKEN = os.getenv("HF_TOKEN")
# =========================
# Carga de modelos
# =========================
def load_models():
    """Load and return ``(ner_tok, ner_mod, norm_tok, norm_mod)``.

    ner_tok/ner_mod: RoBERTa tokenizer + token-classification model used to
    detect TIMEX spans. norm_tok/norm_mod: Gemma tokenizer + the 4-bit
    quantized base model with the TIMEX3 LoRA adapter injected, used for
    normalization. Models are set to eval mode; the NER model is moved to
    CUDA when available, the Gemma model placement is handled by Accelerate.
    """
    # --- NER ---
    ner_tok = AutoTokenizer.from_pretrained(NER_ID, token=HF_TOKEN)
    ner_mod = AutoModelForTokenClassification.from_pretrained(NER_ID, token=HF_TOKEN)
    ner_mod.eval()
    if torch.cuda.is_available():
        ner_mod.to("cuda")
    # --- Base causal LM (Gemma 2B-it) in 4-bit ---
    base_mod = AutoModelForCausalLM.from_pretrained(
        BASE_ID,
        token=HF_TOKEN,
        device_map="auto",  # let Accelerate decide the placement
        quantization_config=quant_config,  # apply 4-bit NF4
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    # --- Tokenizer of the BASE model (not the adapter) ---
    norm_tok = AutoTokenizer.from_pretrained(BASE_ID, use_fast=True, token=HF_TOKEN)
    # Ensure a pad_token exists if the tokenizer ships without one.
    if norm_tok.pad_token is None and norm_tok.eos_token is not None:
        norm_tok.pad_token = norm_tok.eos_token
    # --- Inject the LoRA WITHOUT device_map (avoids meta/offload issues) ---
    norm_mod = PeftModel.from_pretrained(
        base_mod,
        ADAPTER_ID,
        token=HF_TOKEN,
        is_trainable=False,
        offload_state_dict=False,
    )
    norm_mod.eval()
    return ner_tok, ner_mod, norm_tok, norm_mod
# Carga inicial de los modelos
# Load all models once at import time so Gradio callbacks can reuse them.
ner_tok, ner_mod, norm_tok, norm_mod = load_models()

# Resolve the generation stop token safely: prefer Gemma's "<end_of_turn>"
# turn marker, falling back to the tokenizer's generic EOS id when the marker
# is unknown to the vocabulary (or the lookup fails for any reason).
try:
    eos_id = norm_tok.convert_tokens_to_ids("<end_of_turn>")
    if eos_id is None or eos_id == norm_tok.unk_token_id:
        eos_id = norm_tok.eos_token_id
except Exception:
    eos_id = norm_tok.eos_token_id
# =========================
# Lectura de archivos (.txt, .pdf)
# =========================
def read_file(file_obj) -> str:
    """Return the text content of an uploaded file (.pdf or plain text).

    Args:
        file_obj: object exposing a ``.name`` attribute holding the on-disk
            path (as produced by ``gr.File``).

    Returns:
        For PDFs, the concatenation of every page's extracted text, each
        followed by a newline (pages with no extractable text are skipped).
        For anything else, the bytes decoded as UTF-8, falling back to
        Latin-1 (ignoring errors) when the bytes are not valid UTF-8.
    """
    path = file_obj.name
    if path.lower().endswith('.pdf'):
        pages = []
        with pdfplumber.open(path) as pdf:
            for page in pdf.pages:
                txt = page.extract_text()
                if txt:
                    pages.append(txt + '\n')
        return ''.join(pages)
    with open(path, 'rb') as f:
        data = f.read()
    try:
        return data.decode('utf-8')
    # Fix: the original bare `except:` would also swallow KeyboardInterrupt /
    # SystemExit; only a failed decode should trigger the Latin-1 fallback.
    except UnicodeDecodeError:
        return data.decode('latin-1', errors='ignore')
# =========================
# Extracción NER de TIMEX
# =========================
def extract_timex(text: str):
    """Run the NER model over *text* and return the TIMEX strings it finds.

    The text is apostrophe/space-normalised first, then decoded with the
    standard BIO scheme: a B-TIMEX token starts a new span, I-TIMEX extends
    the open one, anything else closes it.
    """
    cleaned = _normalise_spaces(_normalise_apostrophes(text))
    encoding = ner_tok(cleaned, return_tensors="pt", truncation=True)
    if torch.cuda.is_available():
        encoding = {k: v.to("cuda") for k, v in encoding.items()}
    with torch.no_grad():
        scores = ner_mod(**encoding).logits
    label_ids = torch.argmax(scores, dim=2)[0].cpu().numpy()
    pieces = ner_tok.convert_ids_to_tokens(encoding["input_ids"][0])

    spans = []
    buffer = []

    def _flush():
        # Close the currently open span, if any, and reset the buffer.
        if buffer:
            spans.append(ner_tok.convert_tokens_to_string(buffer).strip())
            buffer.clear()

    for piece, label_id in zip(pieces, label_ids):
        tag = ID2LABEL.get(label_id, "O")
        if tag == "B-TIMEX":
            _flush()
            buffer.append(piece)
        elif tag == "I-TIMEX" and buffer:
            buffer.append(piece)
        else:
            _flush()
    _flush()
    return [_clean_timex(span) for span in spans]
# =========================
# Normalización con Gemma + LoRA
# =========================
def normalize_timex(expr: str, dct: str) -> str:
    """Normalize one temporal expression into a TIMEX3 tag via Gemma + LoRA.

    Args:
        expr: the temporal expression extracted by the NER stage.
        dct: anchor date (Document Creation Time) string.

    Returns:
        The generated tag with square brackets mapped to angle brackets.
    """
    # Gemma chat template: one user turn with the task, DCT and expression,
    # then the opening of the model turn (prompt text kept in Spanish).
    prompt = (
        f"<start_of_turn>user\n"
        f"Tu tarea es normalizar la expresión temporal al formato TIMEX3, utilizando la fecha de anclaje (DCT) cuando sea necesaria.\n"
        f"Fecha de Anclaje (DCT): {dct}\n"
        f"Expresión Original: {expr}<end_of_turn>\n"
        f"<start_of_turn>model\n"
    )
    target_device = next(norm_mod.parameters()).device
    encoded = norm_tok(prompt, return_tensors="pt").to(target_device)
    with torch.no_grad():
        generated = norm_mod.generate(
            **encoded,
            max_new_tokens=64,
            eos_token_id=eos_id,
            do_sample=False,
        )
    # Decode only the completion (everything after the prompt tokens).
    prompt_len = encoded.input_ids.shape[1]
    completion = norm_tok.decode(generated[0, prompt_len:], skip_special_tokens=False)
    tag = completion.partition("<end_of_turn>")[0].strip()
    # The adapter emits bracketed tags; swap square brackets for angle brackets.
    return tag.replace("[", "<").replace("]", ">")
# =========================
# Pipeline principal
# =========================
def run_pipeline(files, raw_text, dct):
    """Extract and normalize temporal expressions from pasted text and files.

    Args:
        files: a single uploaded file, a list of them, or None/empty.
        raw_text: text pasted by the user (may be empty).
        dct: anchor date (DCT) string, e.g. "2025-06-11".

    Returns:
        A ``(DataFrame, logs)`` tuple. The DataFrame always carries the
        columns 'Expresión' and 'Normalización', even when nothing was found;
        logs is currently always the empty string.
    """
    rows = []

    def _collect(text: str) -> None:
        # Shared per-line processing for both input channels (was duplicated
        # in the original): skip blank lines, extract and normalize each TIMEX.
        for line in text.splitlines():
            if line.strip():
                for expr in extract_timex(line):
                    rows.append({
                        'Expresión': expr,
                        'Normalización': normalize_timex(expr, dct),
                    })

    # Pasted text
    if raw_text:
        _collect(raw_text)
    # Uploaded files (gr.File may hand us one object or a list)
    file_list = files if isinstance(files, list) else ([files] if files else [])
    for f in file_list:
        _collect(read_file(f))

    df = pd.DataFrame(rows)
    if df.empty:
        df = pd.DataFrame([], columns=['Expresión', 'Normalización'])
    return df, ""
# =========================
# Interfaz Gradio
# =========================
with gr.Blocks() as demo:
    # App description shown at the top of the page (user-facing, kept in Spanish).
    gr.Markdown("""
## TIMEX Extractor & Normalizer
Esta aplicación permite **extraer** expresiones temporales de textos o archivos (.txt, .pdf)
y **normalizarlas** a formato **TIMEX3**.
**Cómo usar:**
1. Sube uno o varios archivos en la columna izquierda.
2. Ajusta la *Fecha de Anclaje (DCT)*.
3. Escribe o pega tu texto en la columna derecha.
4. Pulsa **Procesar** para ver los resultados.
**Columnas de salida:**
- **Expresión**: la frase temporal extraída.
- **Normalización**: la etiqueta TIMEX3 generada.
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Left column: file inputs, anchor date and the run button.
            files = gr.File(file_types=['.txt', '.pdf'], file_count='multiple', label='Archivos (.txt, .pdf)')
            dct_input = gr.Textbox(value="2025-06-11", label="Fecha de Anclaje (YYYY-MM-DD)")
            run_btn = gr.Button("Procesar")
        with gr.Column(scale=2):
            # Right column: free-text input plus results/outputs.
            # NOTE(review): the paste lost all indentation, so the nesting of
            # the four widgets below (this column vs. Blocks top level) is an
            # assumption — confirm against the original layout.
            raw_text = gr.Textbox(lines=15, placeholder='Pega o escribe aquí tu texto...', label='Texto libre')
            output_table = gr.Dataframe(headers=['Expresión', 'Normalización'], label="Resultados", type="pandas")
            output_logs = gr.Textbox(label="Logs", lines=5, interactive=False)
            download_btn = gr.Button("Descargar CSV")
            # Hidden until export_csv below produces a CSV to download.
            csv_file_output = gr.File(label="Descargar resultados en CSV", visible=False)
    # Main processing action
    run_btn.click(
        fn=run_pipeline,
        inputs=[files, raw_text, dct_input],
        outputs=[output_table, output_logs]
    )
    # Export results to CSV
    def export_csv(df):
        # Write the current results table to disk and reveal the download widget.
        csv_path = "resultados.csv"
        df.to_csv(csv_path, index=False)
        return gr.update(value=csv_path, visible=True)
    download_btn.click(
        fn=export_csv,
        inputs=[output_table],
        outputs=[csv_file_output]
    )
# Launch the app (Spaces will pick host/port)
if __name__ == "__main__":
    demo.launch()