# Astro_gpt2 / app.py — Hugging Face Space by Branis333 (commit 5ec7797, verified)
import gradio as gr
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

print("Loading model...")
# Load the fine-tuned model from the Hugging Face Hub.
# The first launch downloads the weights; later launches use the local cache.
MODEL_NAME = "Branis333/astro-gpt2-chatbot"
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)
# Set model to evaluation mode (disables dropout for deterministic inference paths).
model.eval()
print("Model loaded successfully!")
def format_question(question):
    """Normalize a user question so it ends with terminal punctuation.

    Args:
        question (str): The raw input question.

    Returns:
        str: The whitespace-stripped question, with a ``?`` appended
        unless it already ends in ``?``, ``!``, or ``.``.
    """
    text = question.strip()
    terminal_marks = ('?', '!', '.')
    return text if text.endswith(terminal_marks) else text + '?'
def answer_astronomy_question(question, max_length=150, temperature=0.7, top_p=0.9):
    """Generate an answer to an astronomy question.

    Args:
        question (str): The user's question; a '?' is appended if missing.
        max_length (int): Maximum number of NEW tokens to generate.
        temperature (float): Sampling temperature; lower = more focused.
        top_p (float): Nucleus-sampling probability mass; lower = less diverse.

    Returns:
        str: The generated answer text (prompt echo removed).
    """
    # Automatically format question with ? if needed
    formatted_question = format_question(question)
    # Match the "Q: ...\nA:" template the model was fine-tuned on.
    prompt = f"Q: {formatted_question}\nA:"
    inputs = tokenizer(prompt, return_tensors="pt")
    # Generate response; no_grad avoids building an autograd graph at inference.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # Gradio sliders can deliver floats (e.g. 150.0); generate()
            # requires an integer token count, so cast defensively.
            max_new_tokens=int(max_length),
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2,
        )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The model echoes the prompt; keep only the text after the first "A:".
    if "A:" in generated_text:
        answer = generated_text.split("A:", 1)[1].strip()
    else:
        answer = generated_text.strip()
    return answer
# Example rows for the Gradio UI. Each row supplies ALL 4 inputs in order:
# [question, max_length, temperature, top_p] — matching the inputs list below.
examples = [
    ["What is a black hole?", 150, 0.7, 0.9],
    ["What is a constellation?", 150, 0.7, 0.9],
    ["What causes auroras on Earth?", 150, 0.7, 0.9],
    ["Explain the difference between a planet and a star.", 200, 0.7, 0.9],
    ["What is the Big Bang theory?", 200, 0.8, 0.9],
]
# UI copy shown above (description) and below (article) the demo; kept as
# module constants so the gr.Interface(...) call stays compact.
_DESCRIPTION = """
Ask questions about astronomy and space science! This chatbot is powered by a fine-tuned GPT-2 model
trained on 2,736 astronomy Q&A pairs.
**Tip:** You don't need to add a question mark - it will be added automatically! ✨
**Note:** This is an educational tool. Always verify important astronomical facts with authoritative sources.
"""

_ARTICLE = """
### About This Model
- **Base Model:** GPT-2
- **Training Data:** 2,736 cleaned astronomy Q&A pairs
- **Perplexity:** 1.61
- **Specialization:** Astronomy terminology, concepts, and phenomena
### Tips for Best Results:
- Ask specific, clear questions (question mark optional!)
- Lower temperature = more focused answers
- Higher temperature = more creative answers
### Model Repository
[View on Hugging Face](https://huggingface.co/Branis333/astro-gpt2-chatbot)
"""

# Input widgets, in the same order as answer_astronomy_question's parameters:
# question text, then the three generation knobs.
_question_input = gr.Textbox(
    label="Ask an Astronomy Question",
    placeholder="e.g., What is a black hole (question mark is optional)",
    lines=2,
)
_length_slider = gr.Slider(
    minimum=50, maximum=300, value=150, step=10,
    label="Max Answer Length",
)
_temperature_slider = gr.Slider(
    minimum=0.1, maximum=1.0, value=0.7, step=0.1,
    label="Temperature (creativity)",
)
_top_p_slider = gr.Slider(
    minimum=0.1, maximum=1.0, value=0.9, step=0.1,
    label="Top-p (diversity)",
)

# Assemble the Gradio interface around the generation function.
interface = gr.Interface(
    fn=answer_astronomy_question,
    inputs=[_question_input, _length_slider, _temperature_slider, _top_p_slider],
    outputs=gr.Textbox(label="Answer", lines=8),
    examples=examples,
    title="🌌 Astronomy GPT-2 Chatbot",
    description=_DESCRIPTION,
    article=_ARTICLE,
    theme=gr.themes.Soft(),
)
# Launch the Gradio app only when executed as a script (not when imported).
if __name__ == "__main__":
    interface.launch()