# test / app.py
# Author: CaioMartins1
# Last commit: 604235c ("Update app.py")
import torch
import gradio as gr
from transformers import pipeline, BertTokenizer, BertForQuestionAnswering
from datasets import load_dataset
# --- One-time setup, executed when this module is loaded ---

# Advice corpus pulled from the Hugging Face Hub.
advice_dataset = load_dataset("ziq/depression_advice")

# Fine-tuned BERT question-answering checkpoint. The path is relative, so it
# is resolved against the current working directory.
model_dir = "./bert-finetuned-depression"
tokenizer = BertTokenizer.from_pretrained(model_dir)
model = BertForQuestionAnswering.from_pretrained(model_dir)

# Raw passages from the "text" column of the training split.
# NOTE(review): nothing in the visible code reads `contexts`; confirm whether
# these passages are meant to be supplied to the model as QA context.
_train_split = advice_dataset["train"]
contexts = _train_split["text"]
# Define a function to generate answers
def generate_answer(messages, context=None):
    """Extract a short answer span using the fine-tuned BERT QA model.

    Args:
        messages: The user's message, or a list whose first element is the
            message (any further elements are ignored).
        context: Optional passage to answer from. When given, the message is
            encoded as the question and the answer span is restricted to the
            context's tokens. When omitted (the original behavior), the span
            is taken from the message itself.

    Returns:
        The decoded answer text, or "No answer found." when the input is
        empty or the model selects no usable span.
    """
    # Callers may pass a list of inputs; only the first one is the message.
    if isinstance(messages, list):
        messages = messages[0] if messages else None
    if messages is None or not str(messages).strip():
        return "No answer found."

    # BERT has a fixed position-embedding budget: truncate rather than let
    # the model fail on long inputs.
    encode_args = (messages,) if context is None else (messages, context)
    inputs = tokenizer(
        *encode_args,
        return_tensors="pt",
        truncation=True,
        max_length=model.config.max_position_embeddings,
        return_special_tokens_mask=True,
    )
    # The model does not accept this key; it only marks [CLS]/[SEP] positions.
    special_tokens = inputs.pop("special_tokens_mask")[0].bool()

    # Positions the answer may be drawn from. Special tokens are never
    # allowed; with a context, only the second segment (token_type_id 1).
    allowed = ~special_tokens
    if context is not None:
        allowed &= inputs["token_type_ids"][0].bool()
    if not allowed.any():
        return "No answer found."

    with torch.no_grad():
        outputs = model(**inputs)

    neg_inf = float("-inf")
    start_scores = outputs.start_logits[0].masked_fill(~allowed, neg_inf)
    answer_start = int(torch.argmax(start_scores))
    # Picking the end independently can land before the start (an empty
    # slice), so only end positions at or after the chosen start compete.
    end_scores = outputs.end_logits[0].masked_fill(~allowed, neg_inf)
    end_scores[:answer_start] = neg_inf
    answer_end = int(torch.argmax(end_scores)) + 1

    answer_ids = inputs["input_ids"][0][answer_start:answer_end]
    # Drop [CLS]/[SEP] so they never leak into the user-facing answer.
    answer = tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
    return answer if answer else "No answer found."
# --- Web UI: a single text box in, a single text box out ---
message_box = gr.Textbox(type="text", label="Message")
answer_box = gr.Textbox(type="text", label="Answer")

iface = gr.Interface(
    fn=generate_answer,
    inputs=[message_box],
    outputs=answer_box,
    title="Depression Advice Generator",
    description="Enter your feelings, and get supportive advice generated by a fine-tuned BERT model.",
)

# Starts the Gradio server; this call happens whenever the module is run.
iface.launch()