# app.py — Gradio demo for a GPT-2 model fine-tuned on AI ethics.
# (Hugging Face Space by Hokeno; commit cf1e806, verified.)
import os
import sys

import gradio as gr
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
MODEL_DIR = "./gpt2-finetuned-ai-ethics-final"

# Load the fine-tuned model and tokenizer once at startup; abort cleanly if
# the directory is missing or the checkpoint is corrupt.
try:
    tokenizer = GPT2Tokenizer.from_pretrained(MODEL_DIR)
    # GPT-2 ships without a pad token; add one so batched generation can pad.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model = GPT2LMHeadModel.from_pretrained(MODEL_DIR)
    # BUG FIX: adding '[PAD]' grows the tokenizer's vocabulary, so the model's
    # embedding matrix must be resized to match — otherwise the new pad token id
    # indexes past the end of the embeddings at generation time. This is a
    # no-op when the sizes already agree.
    model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # disable dropout etc. — this process only does inference
    print(f"Model and tokenizer successfully loaded from {MODEL_DIR} to {device}.")
except Exception as e:
    print(f"Error loading model or tokenizer: {e}")
    print("Make sure you have run the fine-tuning process and the model is saved in the correct directory.")
    # BUG FIX: exit() is the interactive-shell helper injected by the `site`
    # module; sys.exit(1) is the reliable way to abort a script with an error
    # status.
    sys.exit(1)
def generate_text(prompt, max_length=100, temperature=0.7, top_k=50, top_p=0.95, no_repeat_ngram_size=2):
    """Generate AI-ethics text from the fine-tuned GPT-2 model.

    Args:
        prompt: Seed text to continue. An empty prompt short-circuits with a hint.
        max_length: Total length cap in tokens (prompt tokens included).
        temperature: Softmax temperature; higher = more random.
        top_k: Sample only from the k most likely next tokens (0 disables).
        top_p: Nucleus sampling cumulative-probability cutoff.
        no_repeat_ngram_size: Forbid repeating any n-gram of this size.

    Returns:
        The generated text trimmed to its last complete sentence, or an
        error message string if generation fails.
    """
    if not prompt:
        return "Enter prompt here."
    try:
        # Tokenize via __call__ so we also get an attention mask; generate()
        # otherwise guesses the mask and warns when pad == eos is ambiguous.
        encoded = tokenizer(prompt, return_tensors='pt').to(device)
        with torch.no_grad():  # inference only — skip gradient bookkeeping
            output = model.generate(
                encoded['input_ids'],
                attention_mask=encoded['attention_mask'],
                max_length=max_length,
                num_return_sequences=1,
                # BUG FIX: generate() defaults to greedy decoding
                # (do_sample=False), which silently IGNORES temperature,
                # top_k and top_p — the UI sliders did nothing.
                do_sample=True,
                no_repeat_ngram_size=no_repeat_ngram_size,
                top_k=int(top_k),  # Gradio sliders may deliver floats; top_k must be int
                top_p=top_p,
                temperature=temperature,
                pad_token_id=tokenizer.pad_token_id
            )
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        # Trim to the last complete sentence so the output doesn't stop mid-word.
        last_period_index = generated_text.rfind('.')
        if last_period_index != -1:
            generated_text = generated_text[:last_period_index + 1]
        return generated_text
    except Exception as e:
        # Surface the failure in the UI text box rather than crashing the app.
        return f"An error occurred while generating text: {e}"
# One input component per generate_text keyword argument, in the order
# Gradio passes them positionally after the prompt.
_input_components = [
    gr.Textbox(lines=5, label="Enter your prompt", placeholder="Example: The ethical implications of AI"),
    gr.Slider(minimum=50, maximum=300, value=100, label="Maximum Text Length"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature (Randomness)"),
    gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K (Word Restriction)"),
    gr.Slider(minimum=0.0, maximum=1.0, value=0.95, label="Top-P (Cumulative Probability)"),
    gr.Slider(minimum=1, maximum=5, value=2, step=1, label="N-Gram Size Without Repetition"),
]

# Wire the generation function into a simple web UI.
iface = gr.Interface(
    fn=generate_text,
    inputs=_input_components,
    outputs=gr.Textbox(label="Generated Text", lines=10),
    title="AI Ethical Text Generation Application (GPT-2 Fine-tuned)",
    description="Enter a prompt and the fine-tuned GPT-2 model will generate text related to AI ethics.",
    theme="soft",
)
# Start the web server only when run as a script (not when imported).
if __name__ == "__main__":
    print("Launching the Gradio app...")
    iface.launch(share=False)  # share=False: local-only, no public tunnel