"""Gradio web app that answers astronomy questions with a fine-tuned GPT-2 model."""

import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

print("Loading model...")

# Load your model from Hugging Face
MODEL_NAME = "Branis333/astro-gpt2-chatbot"
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)

# Set model to evaluation mode (disables dropout during generation)
model.eval()

print("Model loaded successfully!")


def format_question(question):
    """
    Automatically add question mark if not present.

    Args:
        question (str | None): The input question.

    Returns:
        str: The stripped question, with "?" appended unless it already ends
        with "?", "!" or ".". Empty/whitespace-only input (or None) yields ""
        so callers can detect it instead of receiving a lone "?".
    """
    question = (question or "").strip()
    # Only punctuate non-empty text; "" must stay "" for the caller's check.
    if question and not question.endswith(("?", "!", ".")):
        question += "?"
    return question


def answer_astronomy_question(question, max_length=150, temperature=0.7, top_p=0.9):
    """
    Generate an answer to an astronomy question.

    Args:
        question (str | None): User question; a "?" is appended if missing.
        max_length (int | float): Maximum number of NEW tokens to generate
            (Gradio sliders may pass this as a float).
        temperature (float): Sampling temperature.
        top_p (float): Nucleus-sampling probability mass.

    Returns:
        str: The generated answer text, or a prompt asking for a question
        when the input is empty.
    """
    formatted_question = format_question(question)
    if not formatted_question:
        return "Please enter an astronomy question."

    # Same "Q: ...\nA:" prompt format the model is queried with.
    prompt = f"Q: {formatted_question}\nA:"

    # Gradio sliders deliver floats; generate() expects an integer token count.
    max_new_tokens = int(max_length)

    inputs = tokenizer(prompt, return_tensors="pt")

    # Prompt + answer must fit in GPT-2's positional-embedding window; an
    # over-long pasted question would otherwise crash generate(). Keep the
    # tail of the prompt so the trailing "\nA:" cue is preserved.
    prompt_budget = model.config.n_positions - max_new_tokens
    if inputs["input_ids"].shape[1] > prompt_budget:
        inputs = {key: value[:, -prompt_budget:] for key, value in inputs.items()}
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2,
        )

    # generate() returns prompt + continuation; decode only the continuation
    # so an "A:" inside the user's question can't be mistaken for the marker.
    answer = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # If the model runs on into another "Q:" turn, keep only the first answer.
    answer = answer.split("\nQ:", 1)[0]
    return answer.strip()


# Examples with ALL 4 parameters (question, max new tokens, temperature, top-p)
examples = [
    ["What is a black hole?", 150, 0.7, 0.9],
    ["What is a constellation?", 150, 0.7, 0.9],
    ["What causes auroras on Earth?", 150, 0.7, 0.9],
    ["Explain the difference between a planet and a star.", 200, 0.7, 0.9],
    ["What is the Big Bang theory?", 200, 0.8, 0.9],
]

# Create Gradio interface
interface = gr.Interface(
    fn=answer_astronomy_question,
    inputs=[
        gr.Textbox(
            label="Ask an Astronomy Question",
            placeholder="e.g., What is a black hole (question mark is optional)",
            lines=2,
        ),
        # Value is passed as max_new_tokens (a token count, not characters).
        gr.Slider(
            minimum=50,
            maximum=300,
            value=150,
            step=10,
            label="Max Answer Length",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.7,
            step=0.1,
            label="Temperature (creativity)",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.9,
            step=0.1,
            label="Top-p (diversity)",
        ),
    ],
    outputs=gr.Textbox(
        label="Answer",
        lines=8,
    ),
    examples=examples,
    title="🌌 Astronomy GPT-2 Chatbot",
    # Markdown content is kept at column 0 so indentation is not rendered
    # as code blocks.
    description="""
Ask questions about astronomy and space science! This chatbot is powered by a fine-tuned GPT-2 model trained on 2,736 astronomy Q&A pairs.

**Tip:** You don't need to add a question mark - it will be added automatically! ✨

**Note:** This is an educational tool. Always verify important astronomical facts with authoritative sources.
""",
    article="""
### About This Model
- **Base Model:** GPT-2
- **Training Data:** 2,736 cleaned astronomy Q&A pairs
- **Perplexity:** 1.61
- **Specialization:** Astronomy terminology, concepts, and phenomena

### Tips for Best Results:
- Ask specific, clear questions (question mark optional!)
- Lower temperature = more focused answers
- Higher temperature = more creative answers

### Model Repository
[View on Hugging Face](https://huggingface.co/Branis333/astro-gpt2-chatbot)
""",
    theme=gr.themes.Soft(),
)

# Launch the app
if __name__ == "__main__":
    interface.launch()