Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import GPT2Tokenizer, GPT2LMHeadModel | |
| import torch | |
# Fetch the fine-tuned checkpoint from the Hugging Face Hub and get it ready
# for inference when the module is imported.
MODEL_NAME = "Branis333/astro-gpt2-chatbot"

print("Loading model...")
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
# nn.Module.eval() returns the module itself, so loading and switching to
# inference mode (e.g. dropout disabled) can be done in one expression.
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).eval()
print("Model loaded successfully!")
def format_question(question, terminators=("?", "!", ".")):
    """
    Automatically add a question mark if the question lacks end punctuation.

    Args:
        question (str | None): The raw question typed by the user.
        terminators (tuple[str, ...]): Endings that count as "already
            punctuated". Defaults to ``?``, ``!`` and ``.``.

    Returns:
        str: The stripped question, with ``?`` appended when it does not
        already end with one of ``terminators``. Empty, whitespace-only or
        ``None`` input yields ``""`` instead of a lone ``"?"``.
    """
    question = (question or "").strip()
    # Nothing to punctuate: returning "?" here would turn an empty submission
    # into a meaningless prompt.
    if not question:
        return ""
    # str.endswith needs a tuple (not a list) when given several suffixes.
    if not question.endswith(tuple(terminators)):
        question += "?"
    return question
def answer_astronomy_question(question, max_length=150, temperature=0.7, top_p=0.9):
    """
    Generate an answer to an astronomy question with the fine-tuned GPT-2 model.

    The question is normalised with ``format_question`` and wrapped in the
    ``"Q: <question>\\nA:"`` prompt format. The continuation is then sampled
    from the module-level ``model`` / ``tokenizer``.

    Args:
        question (str): The user's question; a trailing ``?`` is optional.
        max_length (int | float): Maximum number of *new* tokens to generate.
            Coerced to ``int`` because UI sliders may deliver floats.
        temperature (float): Sampling temperature; should be > 0 since
            sampling is enabled.
        top_p (float): Nucleus-sampling probability mass.

    Returns:
        str: Only the generated answer text (the prompt is never included),
        or a short request for a question when the input is empty.
    """
    # Don't spend a generation call on an empty submission.
    if not (question or "").strip():
        return "Please enter an astronomy question."

    formatted_question = format_question(question)
    prompt = f"Q: {formatted_question}\nA:"
    inputs = tokenizer(prompt, return_tensors="pt")
    # generate() returns prompt + continuation for decoder-only models, so
    # remember where the prompt ends.
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_length),
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2,
        )

    # Decode only the newly generated tokens. Splitting the full decoded text
    # on the first "A:" returned part of the prompt whenever the question
    # itself contained "A:".
    answer = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # Guard against the model continuing with an invented follow-up "Q:" turn;
    # keep only the answer to the question actually asked.
    answer = answer.split("\nQ:", 1)[0]
    return answer.strip()
# Pre-filled example rows for the UI. Each row supplies every interface input,
# in order: question, max answer length, temperature, top-p (fixed at 0.9).
examples = [
    [question, max_len, temp, 0.9]
    for question, max_len, temp in (
        ("What is a black hole?", 150, 0.7),
        ("What is a constellation?", 150, 0.7),
        ("What causes auroras on Earth?", 150, 0.7),
        ("Explain the difference between a planet and a star.", 200, 0.7),
        ("What is the Big Bang theory?", 200, 0.8),
    )
]
# Build the Gradio UI. The four inputs map positionally onto the parameters of
# answer_astronomy_question: question, max_length, temperature, top_p.
interface = gr.Interface(
    fn=answer_astronomy_question,
    inputs=[
        gr.Textbox(
            label="Ask an Astronomy Question",
            placeholder="e.g., What is a black hole (question mark is optional)",
            lines=2,
        ),
        gr.Slider(
            minimum=50,
            maximum=300,
            value=150,
            step=10,
            # Forwarded to generate() as max_new_tokens, so the unit is tokens.
            label="Max Answer Length (tokens)",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.7,
            step=0.1,
            label="Temperature (creativity)",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.9,
            step=0.1,
            label="Top-p (diversity)",
        ),
    ],
    outputs=gr.Textbox(
        label="Answer",
        lines=8,
    ),
    examples=examples,
    # NOTE(review): the original emoji was mis-encoded (shown as "π"); the
    # surviving bytes fit 🌌, which is used as the best guess.
    title="🌌 Astronomy GPT-2 Chatbot",
    # Blank lines separate Markdown paragraphs; without them the Tip and Note
    # render as one run-on paragraph.
    description="""
Ask questions about astronomy and space science! This chatbot is powered by a fine-tuned GPT-2 model
trained on 2,736 astronomy Q&A pairs.

**Tip:** You don't need to add a question mark - it will be added automatically! ✨

**Note:** This is an educational tool. Always verify important astronomical facts with authoritative sources.
""",
    article="""
### About This Model

- **Base Model:** GPT-2
- **Training Data:** 2,736 cleaned astronomy Q&A pairs
- **Perplexity:** 1.61
- **Specialization:** Astronomy terminology, concepts, and phenomena

### Tips for Best Results:

- Ask specific, clear questions (question mark optional!)
- Lower temperature = more focused answers
- Higher temperature = more creative answers

### Model Repository

[View on Hugging Face](https://huggingface.co/Branis333/astro-gpt2-chatbot)
""",
    theme=gr.themes.Soft(),
)

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()