Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from datetime import datetime | |
# Page-wide CSS: restyles Streamlit's text area and buttons and defines the
# custom classes (.code-output, .title, .subtitle, .chat-message) used below.
_CUSTOM_CSS = """
<style>
    .main { background-color: #f9f9f9; padding: 20px; }
    .stTextArea textarea {
        border: 1px solid #ddd;
        border-radius: 8px;
        padding: 10px;
        font-family: 'Roboto', sans-serif;
        font-size: 16px;
        background-color: #fff;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    }
    .stButton button {
        background-color: #4a90e2;
        color: white;
        border-radius: 8px;
        padding: 10px 20px;
        font-family: 'Roboto', sans-serif;
        font-size: 14px;
    }
    .stButton button:hover {
        background-color: #357abd;
    }
    .code-output {
        background-color: #2b2b2b;
        color: #f0f0f0;
        padding: 15px;
        border-radius: 8px;
        font-family: 'Courier New', monospace;
        font-size: 14px;
        margin-top: 10px;
    }
    .title {
        font-family: 'Roboto', sans-serif;
        font-size: 28px;
        font-weight: bold;
        color: #333;
        margin-bottom: 10px;
    }
    .subtitle {
        font-family: 'Roboto', sans-serif;
        font-size: 16px;
        color: #666;
        margin-bottom: 20px;
    }
    .chat-message {
        font-family: 'Roboto', sans-serif;
        font-size: 16px;
        color: #333;
        margin-bottom: 5px;
    }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Cache model and tokenizer to avoid reloading them on every Streamlit rerun.
@st.cache_resource(show_spinner=False)
def _load_checkpoint(checkpoint):
    """Load the tokenizer/model pair for *checkpoint* and keep it in the resource cache.

    Exceptions are deliberately NOT caught here: a call that raises stores
    nothing in the cache, so a failed load is retried on the next rerun instead
    of a ``(None, None)`` result being cached for the life of the process.
    """
    st.write("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    st.write("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    st.write("Model and tokenizer loaded successfully!")
    return tokenizer, model


def load_model_and_tokenizer(checkpoint="Salesforce/codegen-350M-mono"):
    """Return ``(tokenizer, model)`` for *checkpoint*, or ``(None, None)`` on failure.

    Args:
        checkpoint: Hugging Face model id passed to ``from_pretrained``.

    Returns:
        The cached ``(tokenizer, model)`` pair. If loading raises, the error is
        shown with ``st.error`` and ``(None, None)`` is returned so the caller
        can stop the script.
    """
    try:
        return _load_checkpoint(checkpoint)
    except Exception as e:  # UI boundary: report any load failure on the page
        st.error(f"Failed to load model/tokenizer: {e}")
        return None, None
# Obtain the model pair up front. Nothing below works without it, so end this
# script run if loading failed (load_model_and_tokenizer has already reported why).
tokenizer, model = load_model_and_tokenizer()
if any(part is None for part in (tokenizer, model)):
    st.stop()
# Function to generate code
def generate_code(description, max_new_tokens=500):
    """Generate Python code for a natural-language task *description*.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        description: Task description entered by the user.
        max_new_tokens: Maximum number of tokens to generate after the prompt.
            ``max_length`` (used previously) counts the prompt tokens too, so a
            long description left little or no room for the generated code.

    Returns:
        The generated text with the prompt removed and surrounding whitespace
        stripped. If tokenization, generation or decoding raises, the error is
        shown via ``st.error`` and a fixed error string is returned instead.
    """
    prompt = f"Generate Python code for the following task: {description}\n"
    try:
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )
        # Decode only the tokens produced after the prompt. Slicing the decoded
        # string by len(prompt) assumes decoding reproduces the prompt exactly,
        # which tokenizer whitespace clean-up may not guarantee.
        prompt_length = inputs["input_ids"].shape[-1]
        new_tokens = outputs[0][prompt_length:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:  # UI boundary: show the failure rather than crash the app
        st.error(f"Error generating code: {e}")
        return "Error: Could not generate code."
# Create the per-session list of {"input", "output", "time"} entries on first run.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []
# UI layout: page heading, rendered with the custom .title / .subtitle classes.
for _css_class, _text in (
    ("title", "Code Generation Bot"),
    ("subtitle", "Describe your task, and I’ll generate Python code for you!"),
):
    st.markdown(f'<div class="{_css_class}">{_text}</div>', unsafe_allow_html=True)
with st.container():
    # Free-form task description typed by the user.
    description = st.text_area(
        "Enter your description here",
        placeholder="e.g., Write a function to calculate the factorial of a number",
        height=150,
    )
    generate_col, clear_col = st.columns([1, 1])

    with generate_col:
        if st.button("Generate"):
            if not description.strip():
                st.warning("Please enter a description first!")
            else:
                with st.spinner("Thinking..."):
                    generated_code = generate_code(description)
                    entry = {
                        "input": description,
                        "output": generated_code,
                        "time": datetime.now().strftime("%H:%M:%S"),
                    }
                    st.session_state.chat_history.append(entry)

    with clear_col:
        if st.button("Clear History"):
            st.session_state.chat_history = []
            st.success("Chat history cleared!")
import html  # escapes user/model text before it is embedded in raw HTML

# Display chat history
if st.session_state.chat_history:
    st.write("### Chat History")
    for chat in st.session_state.chat_history:
        # Both the description and the generated code are arbitrary text rendered
        # with unsafe_allow_html=True. Escape them so "<", ">" and "&" (e.g.
        # `a < b`, HTML inside strings) are shown literally instead of parsed as markup.
        # Newlines become <br> so a blank line cannot end the HTML block early.
        user_text = html.escape(chat["input"]).replace("\n", "<br>")
        st.markdown(
            f'<div class="chat-message"><strong>You ({chat["time"]}):</strong> {user_text}</div>',
            unsafe_allow_html=True,
        )
        # <pre> keeps the code's line breaks and indentation, which a <div> collapses.
        st.markdown(
            f'<pre class="code-output">{html.escape(chat["output"])}</pre>',
            unsafe_allow_html=True,
        )
        st.markdown("---")
st.info("Tip: Check the generated code for accuracy before using it!")