Tomy07417 commited on
Commit
10f78c0
·
1 Parent(s): 265be3c

Deploy Gradio app loading model from HF Hub

Browse files
Files changed (2) hide show
  1. app.py +83 -0
  2. requirements.txt +199 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # CPU
3
+
4
+ import gradio as gr
5
+ import tensorflow as tf
6
+ from huggingface_hub import hf_hub_download
7
+ from transformers import AutoTokenizer, TFAutoModel
8
+
9
+
10
+ @tf.keras.utils.register_keras_serializable()
11
+ class DistilBertLayer(tf.keras.layers.Layer):
12
+ def __init__(self, model_name="vinai/bertweet-base", **kwargs):
13
+ super().__init__(**kwargs)
14
+ self.model_name = model_name
15
+ self.bert = TFAutoModel.from_pretrained(model_name, from_pt=True)
16
+
17
+ def call(self, inputs):
18
+ input_ids, attention_mask = inputs
19
+ outputs = self.bert(
20
+ input_ids=input_ids,
21
+ attention_mask=attention_mask,
22
+ training=False
23
+ )
24
+ return outputs.last_hidden_state
25
+
26
+ def get_config(self):
27
+ config = super().get_config()
28
+ config.update({"model_name": self.model_name})
29
+ return config
30
+
31
+
32
+ # 1) Repo donde subiste el .keras (MODELS, no Spaces)
33
+ MODEL_REPO = "tomy07417/disaster-tweets-bertweet-gru" # <-- CAMBIÁ ESTO
34
+ MODEL_FILE = "bertweet_gru_model.keras" # <-- nombre exacto en el repo
35
+
36
+ # 2) Descarga con cache (no lo baja cada vez)
37
+ model_path = hf_hub_download(
38
+ repo_id=MODEL_REPO,
39
+ filename=MODEL_FILE,
40
+ repo_type="model"
41
+ )
42
+
43
+ # 3) Cargar el modelo desde el path descargado
44
+ model = tf.keras.models.load_model(
45
+ model_path,
46
+ custom_objects={"DistilBertLayer": DistilBertLayer},
47
+ compile=False
48
+ )
49
+
50
+ tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base")
51
+
52
+
53
+ def predict(text):
54
+ inputs = tokenizer(
55
+ [text],
56
+ max_length=50,
57
+ truncation=True,
58
+ padding="max_length",
59
+ return_tensors="tf"
60
+ )
61
+
62
+ input_ids = inputs["input_ids"]
63
+ attention_mask = inputs["attention_mask"]
64
+
65
+ # si tu salida es (1,) sigmoid:
66
+ prob = model.predict([input_ids, attention_mask])[0][0]
67
+ pred = bool(prob > 0.5)
68
+
69
+ return {"prob": float(prob), "pred": pred}
70
+
71
+
72
+ demo = gr.Interface(
73
+ fn=predict,
74
+ inputs=gr.Textbox(lines=3, label="Tweet"),
75
+ outputs=gr.JSON(label="Result"),
76
+ title="Tweet classifier",
77
+ description="Paste a tweet in English"
78
+ )
79
+
80
+ if __name__ == "__main__":
81
+ # En Spaces NO uses share=True
82
+ demo.launch()
83
+
requirements.txt ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.3.1
2
+ annotated-types==0.7.0
3
+ anyio==4.11.0
4
+ argon2-cffi==25.1.0
5
+ argon2-cffi-bindings==25.1.0
6
+ arrow==1.4.0
7
+ asttokens==3.0.0
8
+ astunparse==1.6.3
9
+ async-lru==2.0.5
10
+ attrs==25.4.0
11
+ babel==2.17.0
12
+ beautifulsoup4==4.14.2
13
+ bleach==6.3.0
14
+ blis==1.3.0
15
+ catalogue==2.0.10
16
+ certifi==2025.10.5
17
+ cffi==2.0.0
18
+ charset-normalizer==3.4.4
19
+ click==8.3.0
20
+ cloudpathlib==0.23.0
21
+ comm==0.2.3
22
+ confection==0.1.5
23
+ contourpy==1.3.3
24
+ cycler==0.12.1
25
+ cymem==2.0.11
26
+ debugpy==1.8.17
27
+ decorator==5.2.1
28
+ defusedxml==0.7.1
29
+ emoji==2.15.0
30
+ en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl#sha256=1932429db727d4bff3deed6b34cfc05df17794f4a52eeb26cf8928f7c1a0fb85
31
+ executing==2.2.1
32
+ fastjsonschema==2.21.2
33
+ filelock==3.20.0
34
+ flatbuffers==25.9.23
35
+ fonttools==4.60.1
36
+ fqdn==1.5.1
37
+ fsspec==2025.10.0
38
+ gast==0.6.0
39
+ gensim==4.4.0
40
+ google-pasta==0.2.0
41
+ grpcio==1.76.0
42
+ gradio>=4.0
43
+ h11==0.16.0
44
+ h5py==3.15.1
45
+ hf-xet==1.2.0
46
+ httpcore==1.0.9
47
+ httpx==0.28.1
48
+ huggingface-hub==0.36.0
49
+ idna==3.11
50
+ ipykernel==7.1.0
51
+ ipython==9.6.0
52
+ ipython_pygments_lexers==1.1.1
53
+ isoduration==20.11.0
54
+ jedi==0.19.2
55
+ Jinja2==3.1.6
56
+ joblib==1.5.2
57
+ json5==0.12.1
58
+ jsonpointer==3.0.0
59
+ jsonschema==4.25.1
60
+ jsonschema-specifications==2025.9.1
61
+ jupyter-events==0.12.0
62
+ jupyter-lsp==2.3.0
63
+ jupyter_client==8.6.3
64
+ jupyter_core==5.9.1
65
+ jupyter_server==2.17.0
66
+ jupyter_server_terminals==0.5.3
67
+ jupyterlab==4.4.10
68
+ jupyterlab_pygments==0.3.0
69
+ jupyterlab_server==2.28.0
70
+ keras==3.12.0
71
+ kiwisolver==1.4.9
72
+ langcodes==3.5.0
73
+ language_data==1.3.0
74
+ lark==1.3.1
75
+ libclang==18.1.1
76
+ marisa-trie==1.3.1
77
+ Markdown==3.10
78
+ markdown-it-py==4.0.0
79
+ MarkupSafe==3.0.3
80
+ matplotlib==3.10.7
81
+ matplotlib-inline==0.2.1
82
+ mdurl==0.1.2
83
+ mistune==3.1.4
84
+ ml_dtypes==0.5.3
85
+ mpmath==1.3.0
86
+ murmurhash==1.0.13
87
+ namex==0.1.0
88
+ narwhals==2.10.2
89
+ nbclient==0.10.2
90
+ nbconvert==7.16.6
91
+ nbformat==5.10.4
92
+ nest-asyncio==1.6.0
93
+ networkx==3.5
94
+ nltk==3.9.2
95
+ notebook==7.4.7
96
+ notebook_shim==0.2.4
97
+ numpy==2.3.4
98
+ nvidia-cublas-cu12==12.8.4.1
99
+ nvidia-cuda-cupti-cu12==12.8.90
100
+ nvidia-cuda-nvrtc-cu12==12.8.93
101
+ nvidia-cuda-runtime-cu12==12.8.90
102
+ nvidia-cudnn-cu12==9.10.2.21
103
+ nvidia-cufft-cu12==11.3.3.83
104
+ nvidia-cufile-cu12==1.13.1.3
105
+ nvidia-curand-cu12==10.3.9.90
106
+ nvidia-cusolver-cu12==11.7.3.90
107
+ nvidia-cusparse-cu12==12.5.8.93
108
+ nvidia-cusparselt-cu12==0.7.1
109
+ nvidia-nccl-cu12==2.27.5
110
+ nvidia-nvjitlink-cu12==12.8.93
111
+ nvidia-nvshmem-cu12==3.3.20
112
+ nvidia-nvtx-cu12==12.8.90
113
+ opt_einsum==3.4.0
114
+ optree==0.17.0
115
+ packaging==25.0
116
+ pandas==2.3.3
117
+ pandocfilters==1.5.1
118
+ parso==0.8.5
119
+ pexpect==4.9.0
120
+ pillow==12.0.0
121
+ platformdirs==4.5.0
122
+ plotly==6.4.0
123
+ preshed==3.0.10
124
+ prometheus_client==0.23.1
125
+ prompt_toolkit==3.0.52
126
+ protobuf==6.33.0
127
+ psutil==7.1.3
128
+ ptyprocess==0.7.0
129
+ pure_eval==0.2.3
130
+ pycparser==2.23
131
+ pydantic==2.12.3
132
+ pydantic_core==2.41.4
133
+ Pygments==2.19.2
134
+ pyparsing==3.2.5
135
+ python-dateutil==2.9.0.post0
136
+ python-json-logger==4.0.0
137
+ pytz==2025.2
138
+ PyYAML==6.0.3
139
+ pyzmq==27.1.0
140
+ referencing==0.37.0
141
+ regex==2025.11.3
142
+ requests==2.32.5
143
+ rfc3339-validator==0.1.4
144
+ rfc3986-validator==0.1.1
145
+ rfc3987-syntax==1.1.0
146
+ rich==14.2.0
147
+ rpds-py==0.28.0
148
+ safetensors==0.6.2
149
+ scikit-learn==1.7.2
150
+ scipy==1.16.3
151
+ seaborn==0.13.2
152
+ Send2Trash==1.8.3
153
+ sentence-transformers==5.1.2
154
+ setuptools==80.9.0
155
+ shellingham==1.5.4
156
+ six==1.17.0
157
+ smart_open==7.4.4
158
+ sniffio==1.3.1
159
+ soupsieve==2.8
160
+ spacy==3.8.7
161
+ spacy-legacy==3.0.12
162
+ spacy-loggers==1.0.5
163
+ srsly==2.5.1
164
+ stack-data==0.6.3
165
+ sympy==1.14.0
166
+ tensorboard==2.20.0
167
+ tensorboard-data-server==0.7.2
168
+ tensorflow_cpu==2.20.0
169
+ termcolor==3.2.0
170
+ terminado==0.18.1
171
+ tf_keras==2.20.1
172
+ thinc==8.3.6
173
+ threadpoolctl==3.6.0
174
+ tinycss2==1.4.0
175
+ tokenizers==0.22.1
176
+ torch==2.9.1
177
+ torchaudio==2.9.1
178
+ torchvision==0.24.1
179
+ tornado==6.5.2
180
+ tqdm==4.67.1
181
+ traitlets==5.14.3
182
+ transformers==4.57.1
183
+ triton==3.5.1
184
+ typer==0.20.0
185
+ typing-inspection==0.4.2
186
+ typing_extensions==4.15.0
187
+ tzdata==2025.2
188
+ uri-template==1.3.0
189
+ urllib3==2.5.0
190
+ wasabi==1.1.3
191
+ wcwidth==0.2.14
192
+ weasel==0.4.1
193
+ webcolors==25.10.0
194
+ webencodings==0.5.1
195
+ websocket-client==1.9.0
196
+ Werkzeug==3.1.3
197
+ wheel==0.45.1
198
+ wrapt==2.0.0
199
+ xgboost==3.1.1