"""Gradio demo that serves text completions from a fine-tuned GPT-2 model."""

import logging

import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Path to the fine-tuned GPT-2 checkpoint. Swap for "gpt2" / "gpt2-medium" /
# "gpt2-large" to use a stock Hugging Face model instead.
model_name = "../custom_model/custom_gpt2"

tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# GPT-2 ships without a dedicated pad token; reuse EOS so the tokenizer can
# pad and generation knows where to stop.
model.config.eos_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token  # Set the pad_token to eos_token


def generate_text(prompt: str) -> str:
    """Generate a text completion for *prompt* and return it stripped.

    Parameters
    ----------
    prompt : str
        User-supplied text; truncated to 512 tokens before generation.
        (The original passed ``max_length=512`` with ``truncation=False``,
        which silently ignored the limit and could overflow GPT-2's
        1024-token context window.)

    Returns
    -------
    str
        The decoded generation (prompt included, as with the original
        ``max_length``-style call), with surrounding whitespace removed.
    """
    # Tokenize the input; truncation=True makes the 512-token cap effective.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    # BUGFIX: top_k/top_p are ignored by transformers unless do_sample=True —
    # the original call omitted it, so decoding was greedy despite the
    # sampling parameters. attention_mask/pad_token_id are passed explicitly
    # to get correct masking and silence the HF warning.
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=150,
        num_return_sequences=1,
        do_sample=True,
        no_repeat_ngram_size=2,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=model.config.eos_token_id,
    )

    # Decode the first (only) returned sequence, dropping special tokens.
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text.strip()


# Gradio interface setup. gr.inputs.Textbox was removed in Gradio 3.x;
# components now live at the top level of the package.
iface = gr.Interface(
    fn=generate_text,
    inputs=gr.Textbox(lines=10, placeholder="Enter your prompt here..."),
    outputs="text",
)

if __name__ == "__main__":
    # share=True creates a temporary public URL via Gradio's tunnel service.
    iface.launch(share=True)