Rhulli commited on
Commit
c578e92
verified
1 Parent(s): b2c6071

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -41
app.py CHANGED
@@ -14,7 +14,9 @@ from transformers import (
14
  )
15
  from peft import PeftModel
16
 
17
- # --- Funciones de normalizaci贸n y limpieza ---
 
 
18
  _SPACE_VARIANTS = r"[\u202f\u00a0\u2009\u200a\u2060]"
19
 
20
  def _normalise_apostrophes(text: str) -> str:
@@ -31,40 +33,62 @@ def _clean_timex(ent: str) -> str:
31
  ent = ent.replace("</s>", "").strip()
32
  return re.sub(r"[\.]+$", "", ent)
33
 
34
- # --- Identificadores de los modelos ----
 
 
35
  NER_ID = "Rhulli/Roberta-ner-temporal-expresions-secondtrain"
36
  ID2LABEL = {0: "O", 1: "B-TIMEX", 2: "I-TIMEX"}
37
  BASE_ID = "google/gemma-2b-it"
38
  ADAPTER_ID = "Rhulli/gemma-2b-it-TIMEX3"
39
 
40
- # --- Configuraci贸n de cuantizaci贸n para el modelo de normalizaci贸n ----
 
 
41
  quant_config = BitsAndBytesConfig(
42
  load_in_4bit=True,
43
  bnb_4bit_quant_type="nf4",
44
  bnb_4bit_compute_dtype=torch.float16,
45
  )
46
 
47
- # --- Leer el token del entorno (a帽adido como Repository Secret) ----
 
 
48
  HF_TOKEN = os.getenv("HF_TOKEN")
49
 
 
 
 
50
  def load_models():
 
51
  ner_tok = AutoTokenizer.from_pretrained(NER_ID, token=HF_TOKEN)
52
  ner_mod = AutoModelForTokenClassification.from_pretrained(NER_ID, token=HF_TOKEN)
53
  ner_mod.eval()
54
  if torch.cuda.is_available():
55
  ner_mod.to("cuda")
56
 
 
57
  base_mod = AutoModelForCausalLM.from_pretrained(
58
  BASE_ID,
59
- device_map="auto",
60
- token=HF_TOKEN
 
 
 
61
  )
62
- norm_tok = AutoTokenizer.from_pretrained(ADAPTER_ID, use_fast=True, token=HF_TOKEN)
 
 
 
 
 
 
 
63
  norm_mod = PeftModel.from_pretrained(
64
  base_mod,
65
  ADAPTER_ID,
66
- device_map="auto",
67
- token=HF_TOKEN
 
68
  )
69
  norm_mod.eval()
70
 
@@ -72,9 +96,18 @@ def load_models():
72
 
73
  # Carga inicial de los modelos
74
  ner_tok, ner_mod, norm_tok, norm_mod = load_models()
75
- eos_id = norm_tok.convert_tokens_to_ids("<end_of_turn>")
76
 
77
- # --- Lectura de archivos ---
 
 
 
 
 
 
 
 
 
 
78
  def read_file(file_obj) -> str:
79
  path = file_obj.name
80
  if path.lower().endswith('.pdf'):
@@ -93,7 +126,9 @@ def read_file(file_obj) -> str:
93
  except:
94
  return data.decode('latin-1', errors='ignore')
95
 
96
- # --- Procesamiento de texto ---
 
 
97
  def extract_timex(text: str):
98
  text_norm = _normalise_spaces(_normalise_apostrophes(text))
99
  inputs = ner_tok(text_norm, return_tensors="pt", truncation=True)
@@ -124,6 +159,9 @@ def extract_timex(text: str):
124
 
125
  return [_clean_timex(e) for e in entities]
126
 
 
 
 
127
  def normalize_timex(expr: str, dct: str) -> str:
128
  prompt = (
129
  f"<start_of_turn>user\n"
@@ -132,8 +170,15 @@ def normalize_timex(expr: str, dct: str) -> str:
132
  f"Expresi贸n Original: {expr}<end_of_turn>\n"
133
  f"<start_of_turn>model\n"
134
  )
135
- inputs = norm_tok(prompt, return_tensors="pt").to(norm_mod.device)
136
- outputs = norm_mod.generate(**inputs, max_new_tokens=64, eos_token_id=eos_id)
 
 
 
 
 
 
 
137
 
138
  full_decoded = norm_tok.decode(
139
  outputs[0, inputs.input_ids.shape[1]:],
@@ -142,11 +187,14 @@ def normalize_timex(expr: str, dct: str) -> str:
142
  raw_tag = full_decoded.split("<end_of_turn>")[0].strip()
143
  return raw_tag.replace("[", "<").replace("]", ">")
144
 
145
- # --- Pipeline principal ---
 
 
146
  def run_pipeline(files, raw_text, dct):
147
  rows = []
148
  file_list = files if isinstance(files, list) else ([files] if files else [])
149
 
 
150
  if raw_text:
151
  for line in raw_text.splitlines():
152
  if line.strip():
@@ -156,6 +204,7 @@ def run_pipeline(files, raw_text, dct):
156
  'Normalizaci贸n': normalize_timex(expr, dct)
157
  })
158
 
 
159
  for f in file_list:
160
  content = read_file(f)
161
  for line in content.splitlines():
@@ -172,29 +221,30 @@ def run_pipeline(files, raw_text, dct):
172
 
173
  return df, ""
174
 
175
- # --- Interfaz Gradio ---
 
 
176
  with gr.Blocks() as demo:
177
- gr.Markdown(
178
- ## TIMEX Extractor & Normalizer
179
- """"
180
- Esta aplicaci贸n permite extraer expresiones temporales de textos o archivos (.txt)
181
- y normalizarlas a formato TIMEX3.
182
-
183
- **C贸mo usar:**
184
- - Sube uno o varios archivos en la columna izquierda.
185
- - Ajusta la *Fecha de Anclaje (DCT)* justo debajo de los archivos.
186
- - Escribe o pega tu texto en la columna derecha.
187
- - Pulsa **Procesar** para ver los resultados en la tabla debajo.
188
-
189
- **Columnas de salida:**
190
- - *Expresi贸n*: la frase temporal extra铆da.
191
- - *Normalizaci贸n*: la etiqueta TIMEX3 generada.
192
- """
193
- )
194
 
195
  with gr.Row():
196
  with gr.Column(scale=1):
197
- files = gr.File(file_types=['.txt'], file_count='multiple', label='Archivos (.txt)')
198
  dct_input = gr.Textbox(value="2025-06-11", label="Fecha de Anclaje (YYYY-MM-DD)")
199
  run_btn = gr.Button("Procesar")
200
  with gr.Column(scale=2):
@@ -203,29 +253,28 @@ with gr.Blocks() as demo:
203
  output_table = gr.Dataframe(headers=['Expresi贸n', 'Normalizaci贸n'], label="Resultados", type="pandas")
204
  output_logs = gr.Textbox(label="Logs", lines=5, interactive=False)
205
 
206
- # Despu茅s de definir output_table y output_logs:
207
  download_btn = gr.Button("Descargar CSV")
208
- csv_file_output = gr.File(label="Descargar resultados en CSV", visible=False)
209
 
210
- # El click de procesar normales.
211
  run_btn.click(
212
  fn=run_pipeline,
213
  inputs=[files, raw_text, dct_input],
214
  outputs=[output_table, output_logs]
215
  )
216
 
217
- # Funci贸n para exportar a CSV
218
  def export_csv(df):
219
  csv_path = "resultados.csv"
220
  df.to_csv(csv_path, index=False)
221
  return gr.update(value=csv_path, visible=True)
222
 
223
- # Asociar el bot贸n de descarga al CSV
224
  download_btn.click(
225
  fn=export_csv,
226
  inputs=[output_table],
227
  outputs=[csv_file_output]
228
  )
229
 
230
- # Lanzar la app
231
- demo.launch()
 
 
14
  )
15
  from peft import PeftModel
16
 
17
+ # =========================
18
+ # Utilidades de normalizaci贸n
19
+ # =========================
20
  _SPACE_VARIANTS = r"[\u202f\u00a0\u2009\u200a\u2060]"
21
 
22
  def _normalise_apostrophes(text: str) -> str:
 
33
  ent = ent.replace("</s>", "").strip()
34
  return re.sub(r"[\.]+$", "", ent)
35
 
36
+ # =========================
37
+ # Identificadores de modelos
38
+ # =========================
39
  NER_ID = "Rhulli/Roberta-ner-temporal-expresions-secondtrain"
40
  ID2LABEL = {0: "O", 1: "B-TIMEX", 2: "I-TIMEX"}
41
  BASE_ID = "google/gemma-2b-it"
42
  ADAPTER_ID = "Rhulli/gemma-2b-it-TIMEX3"
43
 
44
+ # =========================
45
+ # Cuantizaci贸n 4-bit (NF4)
46
+ # =========================
47
  quant_config = BitsAndBytesConfig(
48
  load_in_4bit=True,
49
  bnb_4bit_quant_type="nf4",
50
  bnb_4bit_compute_dtype=torch.float16,
51
  )
52
 
53
+ # =========================
54
+ # Token de HF (si lo usas privado)
55
+ # =========================
56
  HF_TOKEN = os.getenv("HF_TOKEN")
57
 
58
+ # =========================
59
+ # Carga de modelos
60
+ # =========================
61
  def load_models():
62
+ # --- NER ---
63
  ner_tok = AutoTokenizer.from_pretrained(NER_ID, token=HF_TOKEN)
64
  ner_mod = AutoModelForTokenClassification.from_pretrained(NER_ID, token=HF_TOKEN)
65
  ner_mod.eval()
66
  if torch.cuda.is_available():
67
  ner_mod.to("cuda")
68
 
69
+ # --- Base Causal LM (Gemma 2B-it) con 4-bit ---
70
  base_mod = AutoModelForCausalLM.from_pretrained(
71
  BASE_ID,
72
+ token=HF_TOKEN,
73
+ device_map="auto", # deja a Accelerate decidir
74
+ quantization_config=quant_config, # aplica 4-bit NF4
75
+ torch_dtype=torch.float16,
76
+ low_cpu_mem_usage=True,
77
  )
78
+
79
+ # --- Tokenizer del BASE (no del adapter) ---
80
+ norm_tok = AutoTokenizer.from_pretrained(BASE_ID, use_fast=True, token=HF_TOKEN)
81
+ # Asegurar pad_token si falta
82
+ if norm_tok.pad_token is None and norm_tok.eos_token is not None:
83
+ norm_tok.pad_token = norm_tok.eos_token
84
+
85
+ # --- Inyectar el LoRA SIN device_map (evitar meta/offload issues) ---
86
  norm_mod = PeftModel.from_pretrained(
87
  base_mod,
88
  ADAPTER_ID,
89
+ token=HF_TOKEN,
90
+ is_trainable=False,
91
+ offload_state_dict=False,
92
  )
93
  norm_mod.eval()
94
 
 
96
 
97
  # Carga inicial de los modelos
98
  ner_tok, ner_mod, norm_tok, norm_mod = load_models()
 
99
 
100
+ # Determinar eos_id de manera segura
101
+ try:
102
+ eos_id = norm_tok.convert_tokens_to_ids("<end_of_turn>")
103
+ if eos_id is None or eos_id == norm_tok.unk_token_id:
104
+ eos_id = norm_tok.eos_token_id
105
+ except Exception:
106
+ eos_id = norm_tok.eos_token_id
107
+
108
+ # =========================
109
+ # Lectura de archivos (.txt, .pdf)
110
+ # =========================
111
  def read_file(file_obj) -> str:
112
  path = file_obj.name
113
  if path.lower().endswith('.pdf'):
 
126
  except:
127
  return data.decode('latin-1', errors='ignore')
128
 
129
+ # =========================
130
+ # Extracci贸n NER de TIMEX
131
+ # =========================
132
  def extract_timex(text: str):
133
  text_norm = _normalise_spaces(_normalise_apostrophes(text))
134
  inputs = ner_tok(text_norm, return_tensors="pt", truncation=True)
 
159
 
160
  return [_clean_timex(e) for e in entities]
161
 
162
+ # =========================
163
+ # Normalizaci贸n con Gemma + LoRA
164
+ # =========================
165
  def normalize_timex(expr: str, dct: str) -> str:
166
  prompt = (
167
  f"<start_of_turn>user\n"
 
170
  f"Expresi贸n Original: {expr}<end_of_turn>\n"
171
  f"<start_of_turn>model\n"
172
  )
173
+ device = next(norm_mod.parameters()).device
174
+ inputs = norm_tok(prompt, return_tensors="pt").to(device)
175
+ with torch.no_grad():
176
+ outputs = norm_mod.generate(
177
+ **inputs,
178
+ max_new_tokens=64,
179
+ eos_token_id=eos_id,
180
+ do_sample=False,
181
+ )
182
 
183
  full_decoded = norm_tok.decode(
184
  outputs[0, inputs.input_ids.shape[1]:],
 
187
  raw_tag = full_decoded.split("<end_of_turn>")[0].strip()
188
  return raw_tag.replace("[", "<").replace("]", ">")
189
 
190
+ # =========================
191
+ # Pipeline principal
192
+ # =========================
193
  def run_pipeline(files, raw_text, dct):
194
  rows = []
195
  file_list = files if isinstance(files, list) else ([files] if files else [])
196
 
197
+ # Texto pegado
198
  if raw_text:
199
  for line in raw_text.splitlines():
200
  if line.strip():
 
204
  'Normalizaci贸n': normalize_timex(expr, dct)
205
  })
206
 
207
+ # Archivos subidos
208
  for f in file_list:
209
  content = read_file(f)
210
  for line in content.splitlines():
 
221
 
222
  return df, ""
223
 
224
+ # =========================
225
+ # Interfaz Gradio
226
+ # =========================
227
  with gr.Blocks() as demo:
228
+ gr.Markdown("""
229
+ ## TIMEX Extractor & Normalizer
230
+
231
+ Esta aplicaci贸n permite **extraer** expresiones temporales de textos o archivos (.txt, .pdf)
232
+ y **normalizarlas** a formato **TIMEX3**.
233
+
234
+ **C贸mo usar:**
235
+ 1. Sube uno o varios archivos en la columna izquierda.
236
+ 2. Ajusta la *Fecha de Anclaje (DCT)*.
237
+ 3. Escribe o pega tu texto en la columna derecha.
238
+ 4. Pulsa **Procesar** para ver los resultados.
239
+
240
+ **Columnas de salida:**
241
+ - **Expresi贸n**: la frase temporal extra铆da.
242
+ - **Normalizaci贸n**: la etiqueta TIMEX3 generada.
243
+ """)
 
244
 
245
  with gr.Row():
246
  with gr.Column(scale=1):
247
+ files = gr.File(file_types=['.txt', '.pdf'], file_count='multiple', label='Archivos (.txt, .pdf)')
248
  dct_input = gr.Textbox(value="2025-06-11", label="Fecha de Anclaje (YYYY-MM-DD)")
249
  run_btn = gr.Button("Procesar")
250
  with gr.Column(scale=2):
 
253
  output_table = gr.Dataframe(headers=['Expresi贸n', 'Normalizaci贸n'], label="Resultados", type="pandas")
254
  output_logs = gr.Textbox(label="Logs", lines=5, interactive=False)
255
 
 
256
  download_btn = gr.Button("Descargar CSV")
257
+ csv_file_output = gr.File(label="Descargar resultados en CSV", visible=False)
258
 
259
+ # Acci贸n principal de procesamiento
260
  run_btn.click(
261
  fn=run_pipeline,
262
  inputs=[files, raw_text, dct_input],
263
  outputs=[output_table, output_logs]
264
  )
265
 
266
+ # Exportar a CSV
267
  def export_csv(df):
268
  csv_path = "resultados.csv"
269
  df.to_csv(csv_path, index=False)
270
  return gr.update(value=csv_path, visible=True)
271
 
 
272
  download_btn.click(
273
  fn=export_csv,
274
  inputs=[output_table],
275
  outputs=[csv_file_output]
276
  )
277
 
278
+ # Lanzar la app (Spaces recoger谩 host/port)
279
+ if __name__ == "__main__":
280
+ demo.launch()