Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import GPT2Tokenizer, GPT2LMHeadModel | |
| import os | |
# Directory holding the fine-tuned model and tokenizer. It can be overridden
# with the MODEL_DIR environment variable; otherwise the original local path is used.
MODEL_DIR = os.environ.get("MODEL_DIR", "./gpt2-finetuned-ai-ethics-final")
try:
    tokenizer = GPT2Tokenizer.from_pretrained(MODEL_DIR)
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model = GPT2LMHeadModel.from_pretrained(MODEL_DIR)
    # Adding [PAD] above grows the vocabulary by one token. Without resizing,
    # the new pad id would index past the end of the model's embedding matrix.
    if len(tokenizer) > model.get_input_embeddings().num_embeddings:
        model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # inference only: disable dropout
    print(f"Model and tokenizer successfully loaded from {MODEL_DIR} to {device}.")
except Exception as e:
    print(f"Error loading model or tokenizer: {e}")
    print("Make sure you have run the fine-tuning process and the model is saved in the correct directory.")
    # The bare exit() helper comes from the site module (it may be absent) and
    # would report success (status 0). Fail explicitly with a non-zero status.
    raise SystemExit(1)
def _trim_to_last_sentence(text, start=0):
    """Cut *text* just after its last '.', provided that period is at or after *start*.

    If no such period exists, *text* is returned unchanged. A continuation that
    never finishes a sentence is therefore kept as-is rather than discarded.
    """
    cut = text.rfind('.', start)
    return text if cut == -1 else text[:cut + 1]


def generate_text(prompt, max_length=100, temperature=0.7, top_k=50, top_p=0.95, no_repeat_ngram_size=2):
    """Generate a continuation of *prompt* with the fine-tuned GPT-2 model.

    Args:
        prompt: User text to continue. Empty, ``None`` or whitespace-only input
            is rejected with a hint message.
        max_length: Maximum total length in tokens, prompt included (passed to
            ``model.generate`` as ``max_length``).
        temperature: Sampling temperature.
        top_k: Number of most likely tokens kept at each sampling step.
        top_p: Nucleus-sampling cumulative probability threshold.
        no_repeat_ngram_size: Size of n-grams that may not be repeated.

    Returns:
        The decoded prompt plus continuation. If the generated part contains a
        period, the text is trimmed right after the last one. Returns a hint
        string for empty input, or an error message string if generation fails.
    """
    if not prompt or not prompt.strip():
        return "Enter prompt here."
    try:
        encoded = tokenizer(prompt, return_tensors='pt').to(device)
        output = model.generate(
            encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
            # Gradio sliders may deliver floats; these generate() arguments are counts.
            max_length=int(max_length),
            num_return_sequences=1,
            no_repeat_ngram_size=int(no_repeat_ngram_size),
            # Without do_sample=True, generate() decodes greedily and ignores
            # temperature/top_k/top_p, which would make those UI sliders no-ops.
            do_sample=True,
            top_k=int(top_k),
            top_p=float(top_p),
            temperature=float(temperature),
            pad_token_id=tokenizer.pad_token_id,
        )
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        # Only look for the sentence-ending period in the generated part, so a
        # period inside the prompt cannot wipe out the whole continuation.
        prompt_text = tokenizer.decode(encoded['input_ids'][0], skip_special_tokens=True)
        return _trim_to_last_sentence(generated_text, start=len(prompt_text))
    except Exception as e:
        return f"An error occurred while generating text: {e}"
# Gradio front end: one prompt box plus one slider per generation parameter.
# The input values are passed positionally to generate_text in the order listed.
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=5, label="Enter your prompt", placeholder="Example: The ethical implications of AI"),
        # step=1 keeps the token budget integral, like the other count sliders below.
        gr.Slider(minimum=50, maximum=300, value=100, step=1, label="Maximum Text Length"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature (Randomness)"),
        gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K (Word Restriction)"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.95, label="Top-P (Cumulative Probability)"),
        gr.Slider(minimum=1, maximum=5, value=2, step=1, label="N-Gram Size Without Repetition"),
    ],
    outputs=gr.Textbox(label="Generated Text", lines=10),
    title="AI Ethical Text Generation Application (GPT-2 Fine-tuned)",
    description="Enter a prompt and the fine-tuned GPT-2 model will generate text related to AI ethics.",
    theme="soft",
)
if __name__ == "__main__":
    # Script entry point: announce startup, then serve the interface without
    # requesting a public share link.
    startup_message = "Launching the Gradio app..."
    print(startup_message)
    iface.launch(share=False)