import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # CPU
import gradio as gr
import tensorflow as tf
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, TFAutoModel
@tf.keras.utils.register_keras_serializable()
class DistilBertLayer(tf.keras.layers.Layer):
    """Keras layer wrapping a pretrained Hugging Face TF transformer encoder.

    The layer is registered as Keras-serializable and stores only the
    checkpoint name in its config; the encoder itself is re-created from that
    name whenever the layer is constructed.

    NOTE(review): the class is called ``DistilBertLayer`` although the default
    checkpoint is ``vinai/bertweet-base``. The name is presumably what saved
    models reference, so it must not be renamed without re-exporting them.
    """

    def __init__(self, model_name="vinai/bertweet-base", **kwargs):
        """Build the layer and load the encoder identified by *model_name*.

        Args:
            model_name: Checkpoint identifier handed to
                ``TFAutoModel.from_pretrained``.
            **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.model_name = model_name
        # from_pt=True asks transformers to load from PyTorch weights.
        self.bert = TFAutoModel.from_pretrained(model_name, from_pt=True)

    def call(self, inputs):
        """Run the encoder and return its last hidden state.

        Args:
            inputs: A pair ``(input_ids, attention_mask)``.

        Returns:
            The ``last_hidden_state`` attribute of the encoder output.
        """
        ids, mask = inputs
        # The encoder is always invoked with training=False.
        encoded = self.bert(input_ids=ids, attention_mask=mask, training=False)
        return encoded.last_hidden_state

    def get_config(self):
        """Return the base layer config extended with ``model_name``."""
        base = super().get_config()
        base["model_name"] = self.model_name
        return base
# 1) Hub repository that hosts the .keras file (a MODEL repo, not a Space).
MODEL_REPO = "tomy07417/disaster-tweets-bertweet-gru"  # <-- change this to your repo
MODEL_FILE = "bertweet_gru_model.keras"  # <-- exact file name inside the repo

# 2) Fetch the file through the Hub cache so it is not downloaded every time.
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model")

# 3) Load the model from the downloaded path. The custom layer has to be
#    supplied so Keras can rebuild it; compile=False skips compilation.
model = tf.keras.models.load_model(
    model_path,
    compile=False,
    custom_objects={"DistilBertLayer": DistilBertLayer},
)

# Tokenizer for the same checkpoint the custom layer uses by default.
tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base")
def predict(text):
    """Score a single tweet with the loaded model.

    Args:
        text: The tweet text. ``None`` is treated as an empty string.

    Returns:
        A dict with two keys: ``"prob"``, the model score as a Python float,
        and ``"pred"``, ``True`` when that score is strictly above 0.5.
    """
    # Treat a missing value as an empty tweet so the tokenizer is never
    # handed None (which it rejects).
    if text is None:
        text = ""

    # Pad/truncate to a fixed length of 50 tokens.
    # NOTE(review): presumably this matches the length used at training time
    # -- confirm against the training script.
    encoded = tokenizer(
        [text],
        max_length=50,
        truncation=True,
        padding="max_length",
        return_tensors="tf",
    )

    # verbose=0 keeps Keras from printing a progress bar on every request,
    # which would otherwise flood the server log.
    scores = model.predict(
        [encoded["input_ids"], encoded["attention_mask"]],
        verbose=0,
    )

    # Assumes the model output is a single sigmoid unit, i.e. one score per
    # input row -- TODO confirm against the model definition.
    prob = float(scores[0][0])
    return {"prob": prob, "pred": prob > 0.5}
# Gradio UI: one multi-line textbox in, the JSON returned by predict() out.
demo = gr.Interface(
    predict,
    gr.Textbox(lines=3, label="Tweet"),
    gr.JSON(label="Result"),
    title="Tweet classifier",
    description="Paste a tweet in English",
)

if __name__ == "__main__":
    # Do NOT pass share=True when running on Hugging Face Spaces.
    demo.launch(ssr_mode=False)
|