Spaces:
Build error
Build error
| import streamlit as st | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Load pre-trained model and tokenizer from the checkpoint
model_name = "flutter-code-generator/flutter_codegen_model/checkpoint-1500"


@st.cache_resource
def _load_tokenizer_and_model(checkpoint):
    """Load the tokenizer and causal-LM weights for ``checkpoint`` exactly once.

    Streamlit re-executes this script on every widget interaction; caching the
    loaded objects as a shared resource keeps the weights from being re-read
    each time a slider moves or a button is clicked.

    Args:
        checkpoint: Path or identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # The app asks the tokenizer to pad; tokenizers that ship without a pad
    # token would fail there, so fall back to the end-of-sequence token.
    if loaded_tokenizer.pad_token is None and loaded_tokenizer.eos_token is not None:
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token
    # Load the weights a single time (the duplicate load was pure overhead).
    loaded_model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_tokenizer_and_model(model_name)
# Function to clean up repetitive lines in code
def clean_code_response(response):
    """Drop repeated code lines from generated text while keeping its layout.

    A line counts as a repeat when, ignoring leading and trailing whitespace,
    it matches a line that was already kept. Lines that contain no letters or
    digits (blank lines, closing brackets such as ``}`` or ``);``) are never
    treated as repeats, because well-formed code needs them many times over and
    removing them would unbalance the brackets.

    Args:
        response: Generated source text; may be empty.

    Returns:
        The kept lines joined with ``"\\n"``. Leading indentation is preserved
        and trailing whitespace is removed from each line.
    """
    seen = set()  # stripped text of every substantive line kept so far
    kept = []
    for line in response.splitlines():
        key = line.strip()
        # Only substantive lines take part in de-duplication.
        if any(ch.isalnum() for ch in key):
            if key in seen:
                continue
            seen.add(key)
        kept.append(line.rstrip())
    return "\n".join(kept)
# Function to generate Flutter code
def generate_flutter_code(prompt, temperature, top_p, max_length, num_return_sequences, repetition_penalty, top_k):
    """Sample one or more code completions for ``prompt`` from the loaded model.

    Uses the module-level ``tokenizer`` and ``model``. All sampling arguments
    are forwarded unchanged to ``model.generate`` with sampling enabled.

    Args:
        prompt: Text the model should continue.
        temperature: Sampling temperature passed to ``generate``.
        top_p: Nucleus-sampling threshold passed to ``generate``.
        max_length: Length limit passed to ``generate`` as ``max_length``.
        num_return_sequences: Number of sequences to sample.
        repetition_penalty: Repetition penalty passed to ``generate``.
        top_k: Top-k cutoff passed to ``generate``.

    Returns:
        A list of strings, one per returned sequence, each decoded with special
        tokens skipped and post-processed by ``clean_code_response``.
    """
    # A single prompt needs no padding, and requesting it fails outright on
    # tokenizers that have no pad token configured.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
    )

    # generate() needs some pad id to fill finished sequences when several are
    # sampled; fall back to the end-of-sequence id when none is configured.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        inputs["input_ids"],
        # Pass the mask explicitly so generate() does not have to guess it.
        attention_mask=inputs.get("attention_mask"),
        max_length=max_length,
        num_return_sequences=num_return_sequences,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        pad_token_id=pad_token_id,
    )
    code = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
    return [clean_code_response(c) for c in code]
# Page heading shown at the top of the app.
st.title("Flutter Code Generator")

# Starting values for the generation settings exposed in the sidebar.
DEFAULT_TEMPERATURE = 0.7          # sampling temperature
DEFAULT_TOP_P = 0.9                # nucleus-sampling threshold
DEFAULT_MAX_LENGTH = 512           # length limit, in tokens
DEFAULT_NUM_RETURN_SEQUENCES = 1   # how many outputs to sample
DEFAULT_REPETITION_PENALTY = 1.2   # penalty applied to repeated tokens
DEFAULT_TOP_K = 50                 # top-k sampling cutoff
# Sidebar for settings
st.sidebar.title("Generation Settings")

# Session-state key -> default value for each slider below. Binding every
# slider to a key is what lets the reset button actually change its value.
_SETTING_DEFAULTS = {
    "temperature": DEFAULT_TEMPERATURE,
    "top_p": DEFAULT_TOP_P,
    "max_length": DEFAULT_MAX_LENGTH,
    "num_return_sequences": DEFAULT_NUM_RETURN_SEQUENCES,
    "repetition_penalty": DEFAULT_REPETITION_PENALTY,
    "top_k": DEFAULT_TOP_K,
}


def _reset_settings():
    """Restore every generation setting to its default value.

    Runs as a button callback, i.e. before the script reruns and the sliders
    are re-created; writing to a widget's key after the widget has been
    instantiated in the same run is rejected by Streamlit.
    """
    st.session_state.update(_SETTING_DEFAULTS)


# Seed the defaults on first run. The sliders take their value from session
# state, so no explicit `value` argument is passed to them.
for _setting_name, _setting_default in _SETTING_DEFAULTS.items():
    if _setting_name not in st.session_state:
        st.session_state[_setting_name] = _setting_default

temperature = st.sidebar.slider(
    "Temperature (randomness)",
    min_value=0.1,
    max_value=1.0,
    step=0.1,
    key="temperature",
)
top_p = st.sidebar.slider(
    "Top-p (cumulative probability)",
    min_value=0.1,
    max_value=1.0,
    step=0.1,
    key="top_p",
)
max_length = st.sidebar.slider(
    "Max Output Length (tokens)",
    min_value=128,
    max_value=1024,
    step=64,
    key="max_length",
)
num_return_sequences = st.sidebar.slider(
    "Number of Outputs",
    min_value=1,
    max_value=5,
    key="num_return_sequences",
)
repetition_penalty = st.sidebar.slider(
    "Repetition Penalty",
    min_value=1.0,
    max_value=2.0,
    step=0.1,
    key="repetition_penalty",
)
top_k = st.sidebar.slider(
    "Top-k (limit sampling pool)",
    min_value=0,
    max_value=100,
    key="top_k",
)

# Reset to defaults button
st.sidebar.button("Reset to Defaults", on_click=_reset_settings)
# Prompt entry: free text, capped at 200 characters.
user_input = st.text_area(
    "Enter your prompt (e.g., 'Create a responsive login screen'):",
    max_chars=200,
)

# Generation: runs only on the rerun triggered by the button click.
if st.button("Generate Code"):
    prompt = user_input.strip()
    if not prompt:
        # Nothing but whitespace was entered.
        st.error("Please enter a prompt before clicking 'Generate Code'.")
    else:
        generated_code = generate_flutter_code(
            prompt,
            temperature,
            top_p,
            max_length,
            num_return_sequences,
            repetition_penalty,
            top_k,
        )
        # Show each sampled completion under its own numbered heading.
        for position, snippet in enumerate(generated_code, start=1):
            st.subheader(f"Output {position}")
            st.code(snippet, language="dart")