# Hugging Face Spaces page metadata (scrape artifact, not part of the app source):
#   Space status: Sleeping
#   File size: 9,011 bytes
import os
import re
import unicodedata
import io
import torch
import gradio as gr
import pdfplumber
import pandas as pd
from transformers import (
AutoTokenizer,
AutoModelForTokenClassification,
AutoModelForCausalLM,
BitsAndBytesConfig,
)
from peft import PeftModel
# =========================
# Normalization utilities
# =========================
# Exotic whitespace folded to a plain ASCII space: narrow NBSP (U+202F),
# NBSP (U+00A0), thin space (U+2009), hair space (U+200A), word joiner (U+2060).
_SPACE_VARIANTS = r"[\u202f\u00a0\u2009\u200a\u2060]"
def _normalise_apostrophes(text: str) -> str:
return text.replace("´", "'").replace("’", "'")
def _normalise_spaces(text: str, collapse: bool = True) -> str:
    """Map exotic whitespace to spaces, apply NFKC, optionally collapse runs.

    Args:
        text: input string.
        collapse: when True, squeeze runs of two or more spaces into one.

    Returns:
        The normalized string, stripped of leading/trailing whitespace.
    """
    cleaned = unicodedata.normalize("NFKC", re.sub(_SPACE_VARIANTS, " ", text))
    if collapse:
        cleaned = re.sub(r"[ ]{2,}", " ", cleaned)
    return cleaned.strip()
def _clean_timex(ent: str) -> str:
ent = ent.replace("</s>", "").strip()
return re.sub(r"[\.]+$", "", ent)
# =========================
# Model identifiers
# =========================
NER_ID = "Rhulli/Roberta-ner-temporal-expresions-secondtrain"
# BIO tag scheme for the token classifier's three output classes.
ID2LABEL = {0: "O", 1: "B-TIMEX", 2: "I-TIMEX"}
BASE_ID = "google/gemma-2b-it"
ADAPTER_ID = "Rhulli/gemma-2b-it-TIMEX3"
# =========================
# 4-bit quantization (NF4)
# =========================
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
# =========================
# HF token (needed if the repos are private)
# =========================
# None when the env var is unset; passed through to every from_pretrained call.
HF_TOKEN = os.getenv("HF_TOKEN")
# =========================
# Model loading
# =========================
def load_models():
    """Load the NER tagger and the 4-bit-quantized Gemma + LoRA normalizer.

    Returns:
        Tuple ``(ner_tok, ner_mod, norm_tok, norm_mod)``: tokenizer and model
        for token-classification NER, and tokenizer plus PEFT-wrapped causal
        LM for TIMEX3 normalization.
    """
    # --- NER ---
    ner_tok = AutoTokenizer.from_pretrained(NER_ID, token=HF_TOKEN)
    ner_mod = AutoModelForTokenClassification.from_pretrained(NER_ID, token=HF_TOKEN)
    ner_mod.eval()
    if torch.cuda.is_available():
        ner_mod.to("cuda")
    # --- Base causal LM (Gemma 2B-it) in 4-bit ---
    base_mod = AutoModelForCausalLM.from_pretrained(
        BASE_ID,
        token=HF_TOKEN,
        device_map="auto",  # let Accelerate decide placement
        quantization_config=quant_config,  # apply 4-bit NF4
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    # --- Tokenizer of the BASE model (not the adapter) ---
    norm_tok = AutoTokenizer.from_pretrained(BASE_ID, use_fast=True, token=HF_TOKEN)
    # Ensure a pad_token exists if the tokenizer ships without one.
    if norm_tok.pad_token is None and norm_tok.eos_token is not None:
        norm_tok.pad_token = norm_tok.eos_token
    # --- Inject the LoRA WITHOUT device_map (avoid meta/offload issues) ---
    norm_mod = PeftModel.from_pretrained(
        base_mod,
        ADAPTER_ID,
        token=HF_TOKEN,
        is_trainable=False,
        offload_state_dict=False,
    )
    norm_mod.eval()
    return ner_tok, ner_mod, norm_tok, norm_mod
# Initial model load (module level: runs once when the Space starts).
ner_tok, ner_mod, norm_tok, norm_mod = load_models()
# Resolve the generation EOS id safely: prefer Gemma's <end_of_turn> marker,
# fall back to the tokenizer's default eos when that token is unknown.
try:
    eos_id = norm_tok.convert_tokens_to_ids("<end_of_turn>")
    if eos_id is None or eos_id == norm_tok.unk_token_id:
        eos_id = norm_tok.eos_token_id
except Exception:
    eos_id = norm_tok.eos_token_id
# =========================
# File reading (.txt, .pdf)
# =========================
def read_file(file_obj) -> str:
    """Return the text content of an uploaded file.

    Args:
        file_obj: object exposing a ``.name`` attribute holding a filesystem
            path (the shape Gradio's File component provides).

    Returns:
        Extracted text. PDFs are read page by page with pdfplumber (each
        non-empty page followed by a newline); any other file is decoded as
        UTF-8, falling back to Latin-1 with errors ignored.
    """
    path = file_obj.name
    if path.lower().endswith('.pdf'):
        # Collect pages and join once instead of quadratic string +=.
        chunks = []
        with pdfplumber.open(path) as pdf:
            for page in pdf.pages:
                txt = page.extract_text()
                if txt:
                    chunks.append(txt + '\n')
        return ''.join(chunks)
    with open(path, 'rb') as f:
        data = f.read()
    try:
        return data.decode('utf-8')
    # Was a bare `except:`, which would also swallow KeyboardInterrupt/SystemExit;
    # only a decode failure should trigger the Latin-1 fallback.
    except UnicodeDecodeError:
        return data.decode('latin-1', errors='ignore')
# =========================
# TIMEX extraction via NER
# =========================
def extract_timex(text: str):
    """Run the token classifier over *text* and return cleaned TIMEX spans."""
    normalised = _normalise_spaces(_normalise_apostrophes(text))
    encoded = ner_tok(normalised, return_tensors="pt", truncation=True)
    if torch.cuda.is_available():
        encoded = {key: tensor.to("cuda") for key, tensor in encoded.items()}
    with torch.no_grad():
        label_ids = torch.argmax(ner_mod(**encoded).logits, dim=2)[0].cpu().numpy()
    token_strings = ner_tok.convert_ids_to_tokens(encoded["input_ids"][0])

    spans = []
    buffer = []

    def _flush():
        # Close the span being built, if any.
        if buffer:
            spans.append(ner_tok.convert_tokens_to_string(buffer).strip())
            buffer.clear()

    for token, label_id in zip(token_strings, label_ids):
        tag = ID2LABEL.get(label_id, "O")
        if tag == "B-TIMEX":
            _flush()
            buffer.append(token)
        elif tag == "I-TIMEX" and buffer:
            buffer.append(token)
        else:
            _flush()
    _flush()
    return [_clean_timex(span) for span in spans]
# =========================
# Normalization with Gemma + LoRA
# =========================
def normalize_timex(expr: str, dct: str) -> str:
    """Ask the LoRA-tuned Gemma model for the TIMEX3 form of *expr* anchored at *dct*."""
    prompt = (
        f"<start_of_turn>user\n"
        f"Tu tarea es normalizar la expresión temporal al formato TIMEX3, utilizando la fecha de anclaje (DCT) cuando sea necesaria.\n"
        f"Fecha de Anclaje (DCT): {dct}\n"
        f"Expresión Original: {expr}<end_of_turn>\n"
        f"<start_of_turn>model\n"
    )
    model_device = next(norm_mod.parameters()).device
    encoded = norm_tok(prompt, return_tensors="pt").to(model_device)
    with torch.no_grad():
        generated = norm_mod.generate(
            **encoded,
            max_new_tokens=64,
            eos_token_id=eos_id,
            do_sample=False,
        )
    # Decode only the completion (everything after the prompt tokens).
    prompt_len = encoded.input_ids.shape[1]
    completion = norm_tok.decode(generated[0, prompt_len:], skip_special_tokens=False)
    tag_text, _, _ = completion.partition("<end_of_turn>")
    # The adapter emits [TIMEX3 ...] brackets; convert them to XML angle brackets.
    return tag_text.strip().replace("[", "<").replace("]", ">")
# =========================
# Main pipeline
# =========================
def _timex_rows(text: str, dct: str) -> list:
    """Extract and normalize every TIMEX in each non-blank line of *text*."""
    rows = []
    for line in text.splitlines():
        if line.strip():
            for expr in extract_timex(line):
                rows.append({
                    'Expresión': expr,
                    'Normalización': normalize_timex(expr, dct),
                })
    return rows


def run_pipeline(files, raw_text, dct):
    """Build the results table from pasted text and/or uploaded files.

    Args:
        files: a Gradio file object, a list of them, or None/empty.
        raw_text: free text pasted by the user (may be empty or None).
        dct: anchor date (DCT, YYYY-MM-DD) forwarded to the normalizer.

    Returns:
        ``(df, logs)``: DataFrame with 'Expresión'/'Normalización' columns
        (columns always present, even when empty) and an empty log string.
    """
    file_list = files if isinstance(files, list) else ([files] if files else [])
    rows = []
    # Same order as before: pasted text first, then each uploaded file.
    if raw_text:
        rows.extend(_timex_rows(raw_text, dct))
    for f in file_list:
        rows.extend(_timex_rows(read_file(f), dct))
    df = pd.DataFrame(rows)
    if df.empty:
        df = pd.DataFrame([], columns=['Expresión', 'Normalización'])
    return df, ""
# =========================
# Gradio interface
# =========================
with gr.Blocks() as demo:
    gr.Markdown("""
## TIMEX Extractor & Normalizer
Esta aplicación permite **extraer** expresiones temporales de textos o archivos (.txt, .pdf)
y **normalizarlas** a formato **TIMEX3**.
**Cómo usar:**
1. Sube uno o varios archivos en la columna izquierda.
2. Ajusta la *Fecha de Anclaje (DCT)*.
3. Escribe o pega tu texto en la columna derecha.
4. Pulsa **Procesar** para ver los resultados.
**Columnas de salida:**
- **Expresión**: la frase temporal extraída.
- **Normalización**: la etiqueta TIMEX3 generada.
    """)
    with gr.Row():
        with gr.Column(scale=1):
            files = gr.File(file_types=['.txt', '.pdf'], file_count='multiple', label='Archivos (.txt, .pdf)')
            dct_input = gr.Textbox(value="2025-06-11", label="Fecha de Anclaje (YYYY-MM-DD)")
            run_btn = gr.Button("Procesar")
        with gr.Column(scale=2):
            raw_text = gr.Textbox(lines=15, placeholder='Pega o escribe aquí tu texto...', label='Texto libre')
    output_table = gr.Dataframe(headers=['Expresión', 'Normalización'], label="Resultados", type="pandas")
    output_logs = gr.Textbox(label="Logs", lines=5, interactive=False)
    download_btn = gr.Button("Descargar CSV")
    csv_file_output = gr.File(label="Descargar resultados en CSV", visible=False)

    # Main processing action: table + logs are updated together.
    run_btn.click(
        fn=run_pipeline,
        inputs=[files, raw_text, dct_input],
        outputs=[output_table, output_logs]
    )

    # Export to CSV
    def export_csv(df):
        # Writes the current results table and reveals the hidden File widget.
        # NOTE(review): fixed path — concurrent users of a shared Space would
        # overwrite each other's export; a per-request temp file would be safer.
        csv_path = "resultados.csv"
        df.to_csv(csv_path, index=False)
        return gr.update(value=csv_path, visible=True)

    download_btn.click(
        fn=export_csv,
        inputs=[output_table],
        outputs=[csv_file_output]
    )

# Launch the app (Spaces supplies host/port).
if __name__ == "__main__":
    demo.launch()