File size: 2,771 Bytes
a486c98
 
 
 
 
 
 
 
 
 
 
 
2b04a5d
 
 
a486c98
 
 
 
2b04a5d
 
a486c98
 
2b04a5d
 
a486c98
 
 
 
2b04a5d
a486c98
 
 
 
 
 
2b04a5d
a486c98
 
 
 
 
2b04a5d
a486c98
 
 
 
2b04a5d
 
 
a486c98
2b04a5d
a486c98
 
2b04a5d
a486c98
 
 
 
2b04a5d
 
 
 
 
 
a486c98
2b04a5d
 
 
cf1e806
a486c98
 
 
2b04a5d
65b8c83
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import os

MODEL_DIR = "./gpt2-finetuned-ai-ethics-final"

# Load the fine-tuned tokenizer and model once, at import time, so every
# request handled by the Gradio app reuses the same objects.
try:
    tokenizer = GPT2Tokenizer.from_pretrained(MODEL_DIR)

    # GPT-2 has no padding token by default. Reuse EOS instead of adding a new
    # "[PAD]" token: a freshly added token gets an id equal to the old vocab
    # size, which has no row in the model's embedding matrix unless the
    # embeddings are resized, so model.config.pad_token_id would be out of range.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = GPT2LMHeadModel.from_pretrained(MODEL_DIR)

    # Keep the model's notion of padding in sync with the tokenizer's.
    model.config.pad_token_id = tokenizer.pad_token_id

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # evaluation mode (e.g. disables dropout) for inference
    print(f"Model and tokenizer successfully loaded from {MODEL_DIR} to {device}.")

except Exception as e:
    # The app cannot work without a model. Abort with a non-zero exit status
    # (a bare exit() would report success) and send the message to stderr.
    raise SystemExit(
        f"Error loading model or tokenizer: {e}\n"
        "Make sure you have run the fine-tuning process and the model is saved in the correct directory."
    ) from e

def _trim_to_last_sentence(text, start=0):
    """Cut *text* just after its last '.' found at or beyond index *start*.

    Periods before *start* (i.e. inside the user's prompt) are ignored, so a
    prompt such as "Dr. Smith argues" cannot cause the whole generated
    continuation to be thrown away. Text with no such period is returned
    unchanged.
    """
    last_period = text.rfind(".", start)
    return text if last_period == -1 else text[: last_period + 1]


def generate_text(prompt, max_length=100, temperature=0.7, top_k=50, top_p=0.95, no_repeat_ngram_size=2):
    """Generate a continuation of *prompt* with the fine-tuned GPT-2 model.

    Args:
        prompt: Text to continue. None, empty, or whitespace-only prompts are
            rejected with a hint message.
        max_length: Maximum total length in tokens (prompt included), as
            passed to ``model.generate``.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Nucleus (top-p) sampling threshold.
        no_repeat_ngram_size: Size of n-grams that may not be repeated.

    Returns:
        The decoded prompt plus continuation, trimmed to end at the last full
        stop in the generated part (if there is one). If the prompt is empty
        or generation raises, a human-readable message is returned instead,
        because the result is shown directly in the UI.
    """
    if not prompt or not prompt.strip():
        return "Enter prompt here."

    try:
        # Tokenizing via __call__ also yields an attention_mask for generate().
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        output = model.generate(
            **inputs,
            # Gradio sliders may deliver floats; the integer-valued options
            # are coerced explicitly.
            max_length=int(max_length),
            num_return_sequences=1,
            no_repeat_ngram_size=int(no_repeat_ngram_size),
            # Without do_sample=True, generate() decodes greedily and the
            # temperature / top_k / top_p settings are ignored.
            do_sample=True,
            top_k=int(top_k),
            top_p=float(top_p),
            temperature=float(temperature),
            pad_token_id=tokenizer.pad_token_id,
        )

        prompt_text = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

        # Drop a dangling, unfinished final sentence, but only inside the
        # model's continuation, never inside the prompt.
        return _trim_to_last_sentence(generated_text, start=len(prompt_text))

    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing.
        return f"An error occurred while generating text: {e}"

# Web UI: a prompt box followed by one slider per tuning keyword of
# generate_text(), in the same order as that function's parameters.
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(
            label="Enter your prompt",
            placeholder="Example: The ethical implications of AI",
            lines=5,
        ),
        gr.Slider(label="Maximum Text Length", minimum=50, maximum=300, value=100),
        gr.Slider(label="Temperature (Randomness)", minimum=0.1, maximum=1.0, value=0.7),
        gr.Slider(label="Top-K (Word Restriction)", minimum=0, maximum=100, value=50, step=1),
        gr.Slider(label="Top-P (Cumulative Probability)", minimum=0.0, maximum=1.0, value=0.95),
        gr.Slider(label="N-Gram Size Without Repetition", minimum=1, maximum=5, value=2, step=1),
    ],
    outputs=gr.Textbox(label="Generated Text", lines=10),
    theme="soft",
    title="AI Ethical Text Generation Application (GPT-2 Fine-tuned)",
    description="Enter a prompt and the fine-tuned GPT-2 model will generate text related to AI ethics.",
)

if __name__ == "__main__":
    # share=False: do not create a public Gradio share link.
    print("Launching the Gradio app...")
    iface.launch(share=False)