Rhulli committed on
Commit
bcbefbd
verified
1 Parent(s): b8348c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -10
app.py CHANGED
@@ -10,7 +10,6 @@ from transformers import (
10
  AutoTokenizer,
11
  AutoModelForTokenClassification,
12
  AutoModelForCausalLM,
13
- BitsAndBytesConfig,
14
  )
15
  from peft import PeftModel
16
 
@@ -37,13 +36,6 @@ ID2LABEL = {0: "O", 1: "B-TIMEX", 2: "I-TIMEX"}
37
  BASE_ID = "google/gemma-2b-it"
38
  ADAPTER_ID = "Rhulli/gemma-2b-it-TIMEX3"
39
 
40
- # --- Configuración de cuantización para el modelo de normalización ---
41
- quant_config = BitsAndBytesConfig(
42
- load_in_4bit=True,
43
- bnb_4bit_quant_type="nf4",
44
- bnb_4bit_compute_dtype=torch.float16,
45
- )
46
-
47
  # --- Leer el token del entorno (añadido como Repository Secret) ---
48
  HF_TOKEN = os.getenv("HF_TOKEN")
49
 
@@ -55,13 +47,14 @@ def load_models():
55
  if torch.cuda.is_available():
56
  ner_mod.to("cuda")
57
 
58
- # Carga del modelo de normalización (LoRA + 4bit)
59
  base_mod = AutoModelForCausalLM.from_pretrained(
60
  BASE_ID,
61
- quantization_config=quant_config,
62
  device_map="auto",
63
  token=HF_TOKEN
64
  )
 
 
65
  norm_tok = AutoTokenizer.from_pretrained(ADAPTER_ID, use_fast=True, token=HF_TOKEN)
66
  norm_mod = PeftModel.from_pretrained(
67
  base_mod,
 
10
  AutoTokenizer,
11
  AutoModelForTokenClassification,
12
  AutoModelForCausalLM,
 
13
  )
14
  from peft import PeftModel
15
 
 
36
  BASE_ID = "google/gemma-2b-it"
37
  ADAPTER_ID = "Rhulli/gemma-2b-it-TIMEX3"
38
 
 
 
 
 
 
 
 
39
  # --- Leer el token del entorno (añadido como Repository Secret) ---
40
  HF_TOKEN = os.getenv("HF_TOKEN")
41
 
 
47
  if torch.cuda.is_available():
48
  ner_mod.to("cuda")
49
 
50
+ # Carga del modelo base de normalización (sin cuantización)
51
  base_mod = AutoModelForCausalLM.from_pretrained(
52
  BASE_ID,
 
53
  device_map="auto",
54
  token=HF_TOKEN
55
  )
56
+
57
+ # Carga del tokenizer y adaptador LoRA
58
  norm_tok = AutoTokenizer.from_pretrained(ADAPTER_ID, use_fast=True, token=HF_TOKEN)
59
  norm_mod = PeftModel.from_pretrained(
60
  base_mod,