Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
# Configure the browser tab (title and icon) and use Streamlit's wide layout.
st.set_page_config(
    page_title="GPT-2 Code Generator",
    page_icon="💻",
    layout="wide",
)

# Title and description shown at the top of the page.
st.title("💻 GPT-2 Code Generation Model Tester")
st.markdown(
    "Testing model: **[ErikDaska/lr_5e-05](https://huggingface.co/ErikDaska/lr_5e-05)**"
)
st.write(
    "Enter a prompt or a partial function definition below to see how your model completes it."
)
# Streamlit re-executes this script on every widget interaction. Caching the
# loaded resources keeps the model and tokenizer from being re-downloaded and
# re-initialised on each slider change or button press. The spinner text is
# shown only on a cache miss, i.e. while the model is actually being loaded.
@st.cache_resource(
    show_spinner="Loading model and tokenizer from Hugging Face... This might take a minute."
)
def load_model():
    """Load the tokenizer and causal-LM model for ``ErikDaska/lr_5e-05``.

    The model is moved to CUDA when torch reports a GPU, otherwise it stays
    on the CPU.

    Returns:
        tuple: ``(tokenizer, model, device)``, where ``device`` is the string
        ``"cuda"`` or ``"cpu"`` that the model was moved to.
    """
    model_name = "ErikDaska/lr_5e-05"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    # Prefer the GPU when one is visible to torch; fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    return tokenizer, model, device
# Load the model. On failure, show the error and halt the script, because
# everything below needs the tokenizer, model and device.
try:
    tokenizer, model, device = load_model()
except Exception as e:
    st.error(f"Error loading model: {e}")
    st.stop()
else:
    # Kept outside the ``try`` so that a UI error here is not misreported
    # as a model-loading failure.
    st.success("Model loaded successfully!", icon="✅")
# Sidebar: generation controls read by the "Generate Code" handler below.
# Calls inside ``with st.sidebar`` render into the sidebar container.
with st.sidebar:
    st.header("Generation Settings")
    max_length = st.slider(
        "Max Length", min_value=10, max_value=512, value=128, step=10
    )
    temperature = st.slider(
        "Temperature (Creativity)", min_value=0.1, max_value=1.5, value=0.7, step=0.1
    )
    top_p = st.slider(
        "Top-p (Nucleus Sampling)", min_value=0.0, max_value=1.0, value=0.9, step=0.05
    )
    do_sample = st.checkbox("Use Sampling", value=True)
# Main panel: the code prompt the model will continue. The default is a
# function header followed by a comment describing the intended body.
_DEFAULT_PROMPT = "def calculate_factorial(n):\n # This function calculates the factorial of a number"
prompt = st.text_area("Enter Code Prompt:", value=_DEFAULT_PROMPT, height=150)
# Generation trigger: this branch runs on the rerun caused by the button click.
if st.button("Generate Code", type="primary"):
    if not prompt.strip():
        st.warning("Please enter a prompt first.")
    else:
        # Encode the prompt and move the tensors to the model's device.
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        prompt_len = inputs["input_ids"].shape[1]

        # The sidebar value is the number of tokens to add after the prompt.
        # Cap it so prompt + continuation never exceeds the model's
        # position-embedding table; indexing past it fails inside the model.
        context_limit = getattr(model.config, "max_position_embeddings", None)
        if context_limit is None:
            new_tokens = max_length
        else:
            new_tokens = min(max_length, context_limit - prompt_len)

        if new_tokens <= 0:
            st.warning(
                f"The prompt is {prompt_len} tokens long, which already fills the "
                f"model's {context_limit}-token context. Please shorten it."
            )
        else:
            # temperature/top_p only affect sampling; transformers warns when
            # they are supplied together with greedy decoding.
            sampling_kwargs = (
                {"temperature": temperature, "top_p": top_p} if do_sample else {}
            )
            with st.spinner("Generating code..."):
                with torch.no_grad():
                    output_sequences = model.generate(
                        input_ids=inputs["input_ids"],
                        attention_mask=inputs.get("attention_mask"),
                        max_new_tokens=new_tokens,
                        do_sample=do_sample,
                        pad_token_id=tokenizer.eos_token_id,
                        **sampling_kwargs,
                    )
                # Decode the whole sequence: the prompt followed by the continuation.
                generated_code = tokenizer.decode(
                    output_sequences[0], skip_special_tokens=True
                )

            # Display results
            st.subheader("Generated Output:")
            st.code(generated_code, language="python")