smartquote / app.py
dnzita's picture
Update app.py
c7ee385 verified
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
import numpy as np
# --- Load the models (once, at module import time) ---
_MPNET_CHECKPOINT = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
_BERTIMBAU_CHECKPOINT = "neuralmind/bert-base-portuguese-cased"

# Multilingual sentence-embedding model, used when model_choice == "mpnet".
mpnet_model = SentenceTransformer(_MPNET_CHECKPOINT)
# Portuguese BERT (BERTimbau): tokenizer plus bare encoder; sentence vectors
# are obtained by mean-pooling its token embeddings.
bert_tokenizer = AutoTokenizer.from_pretrained(_BERTIMBAU_CHECKPOINT)
bert_model = AutoModel.from_pretrained(_BERTIMBAU_CHECKPOINT)
# Helper: attention-masked mean pooling over BERTimbau token embeddings.
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions selected by the attention mask.

    Args:
        model_output: Model output whose first element holds the token
            embeddings, shaped (batch_size, seq_len, hidden_size).
        attention_mask: (batch_size, seq_len) tensor; non-zero entries mark
            tokens that take part in the average, zeros mark padding.

    Returns:
        A (batch_size, hidden_size) float tensor of per-sequence mean
        embeddings. A row whose mask is entirely zero yields zeros (the
        denominator is clamped, so there is no division by zero).
    """
    hidden_states = model_output[0]  # (batch_size, seq_len, hidden_size)
    # Broadcast the mask across the hidden dimension so padded positions
    # contribute nothing to the sum.
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    summed = (hidden_states * weights).sum(dim=1)
    # Clamp the token count to avoid dividing by zero for all-padding rows.
    token_counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / token_counts
# --- Embedding functions ---
def embed_text(texts, model_choice):
    """Embed one text or a batch of texts with the selected model.

    Args:
        texts: A single string, or an iterable of strings.
        model_choice: "mpnet" (multilingual SentenceTransformer) or
            "bertimbau" (Portuguese BERT with attention-masked mean pooling).

    Returns:
        A list with one embedding per input text, each a list of floats that
        has been L2-normalised. An empty batch yields an empty list.

    Raises:
        ValueError: If model_choice is not "mpnet" or "bertimbau".
    """
    # Validate first so an invalid choice is reported even for an empty batch.
    if model_choice not in ("mpnet", "bertimbau"):
        raise ValueError("Escolha inválida de modelo: use 'mpnet' ou 'bertimbau'.")

    if isinstance(texts, str):
        texts = [texts]
    else:
        texts = list(texts)

    # Nothing to encode: avoid calling the models with an empty batch.
    if not texts:
        return []

    if model_choice == "mpnet":
        embeddings = mpnet_model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
    else:
        # BERT has a fixed number of position embeddings. `truncation=True`
        # on its own only truncates to the tokenizer's `model_max_length`,
        # which may be unset (effectively unbounded) for some checkpoints, so
        # long inputs could overflow the position table. Cap it explicitly.
        max_length = min(
            bert_tokenizer.model_max_length,
            bert_model.config.max_position_embeddings,
        )
        encoded = bert_tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        with torch.no_grad():
            model_out = bert_model(**encoded)
        pooled = mean_pooling(model_out, encoded["attention_mask"])
        embeddings = torch.nn.functional.normalize(pooled, p=2, dim=1).cpu().numpy()

    # One plain Python list per text so the result is JSON-serialisable.
    return [emb.tolist() for emb in embeddings]
# --- Gradio interface ---
# Inputs: a free-text box and a selector for the embedding backend that is
# forwarded to embed_text as model_choice.
# NOTE(review): a Textbox presumably delivers a single string per request, so
# each call yields one embedding despite the "lista de textos" label — confirm
# whether multi-line input was meant to be split into several texts.
_text_input = gr.Textbox(label="Texto ou lista de textos", lines=3, placeholder="Digite aqui...")
_model_selector = gr.Radio(choices=["mpnet", "bertimbau"], label="Modelo", value="mpnet")

demo = gr.Interface(
    fn=embed_text,
    inputs=[_text_input, _model_selector],
    # The embeddings come back as nested lists, rendered as JSON.
    outputs=gr.JSON(label="Embeddings (768 dimensões)"),
    title="Embeddings API - BERTimbau & MPNet",
    description="Gere embeddings de frases em Português (BERTimbau) ou multilíngues (MPNet).",
)

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()