sample_coder / src /streamlit_app.py
pradeep4321's picture
Update src/streamlit_app.py
d73111a verified
raw
history blame
2.68 kB
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# ==============================
# PAGE CONFIG
# ==============================
# set_page_config must run before any other Streamlit command (including the
# st.markdown call below). Streamlit has historically raised
# StreamlitAPIException when it is not the first command, so calling it
# first is safe on every version.
st.set_page_config(page_title="💻 AI Code Generator", layout="wide")
# ==============================
# HIDE STREAMLIT MENU
# ==============================
# Inject CSS that hides Streamlit's main menu, header, footer and deploy button.
st.markdown("""
<style>
#MainMenu {visibility: hidden;}
header {visibility: hidden;}
footer {visibility: hidden;}
.stDeployButton {display:none;}
</style>
""", unsafe_allow_html=True)
# ==============================
# LOAD MODEL
# ==============================
@st.cache_resource
def load_model():
    """Load the CodeGemma tokenizer and causal-LM model.

    Decorated with ``st.cache_resource`` so Streamlit reuses the loaded
    objects across script reruns instead of reloading them each time.

    Returns:
        A ``(tokenizer, model)`` tuple for ``google/codegemma-2b``. The model
        is loaded with float32 weights and placed on the CPU.
    """
    checkpoint = "google/codegemma-2b"
    load_kwargs = {"torch_dtype": torch.float32, "device_map": "cpu"}

    tok = AutoTokenizer.from_pretrained(checkpoint)
    lm = AutoModelForCausalLM.from_pretrained(checkpoint, **load_kwargs)
    return tok, lm


tokenizer, model = load_model()
# ==============================
# CLEAN OUTPUT (IMPORTANT FIX)
# ==============================
def extract_code(text):
    """Return the first Markdown-fenced code block in *text*, if any.

    Model output often wraps code in a fence such as ```python ... ```.
    The word right after the opening fence is a language tag (the fence's
    "info string"), not code, so it is dropped from the result.

    Args:
        text: Raw decoded model output.

    Returns:
        The stripped contents of the first fenced block, without its language
        tag. An unterminated fence yields everything after it. When *text*
        contains no fence, ``text.strip()`` is returned.
    """
    if "```" not in text:
        return text.strip()

    # split() on a separator that is present always yields >= 2 parts,
    # so index 1 (the text after the first fence) always exists.
    block = text.split("```")[1]

    first_line, newline, rest = block.partition("\n")
    tag = first_line.strip()
    # Drop a bare tag such as "python", "c++" or "objective-c" (or an empty
    # fence line). A first line that looks like real code (e.g. contains
    # spaces or "=") is kept.
    if newline and all(ch.isalnum() or ch in "+#.-_" for ch in tag):
        block = rest
    return block.strip()
# ==============================
# GENERATE CODE (SIMPLIFIED PROMPT)
# ==============================
def generate_code(prompt, language, max_new_tokens=120):
    """Generate *language* code for the task in *prompt* using greedy decoding.

    Args:
        prompt: Free-text task description entered by the user.
        language: Target language name inserted into the instruction.
        max_new_tokens: Upper bound on generated tokens. Longer answers are
            cut off at this limit. Defaults to 120, the previous hard-coded
            value.

    Returns:
        The model's continuation (prompt excluded), passed through
        ``extract_code`` to unwrap a Markdown code fence if present.
    """
    full_prompt = (
        "\n"
        f"Write a {language} function for the following task:\n"
        f"{prompt}\n"
        "Only return code.\n"
    )
    inputs = tokenizer(full_prompt, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding. `temperature` is deliberately not passed: it
            # only applies when sampling, and transformers warns when it is
            # combined with do_sample=False.
            do_sample=False,
        )

    # Decode only the newly generated tokens. Removing the prompt with
    # str.replace() on the decoded text is fragile, because decoding does not
    # always reproduce the prompt byte-for-byte, so prompt text can leak
    # into the output.
    new_tokens = outputs[0][prompt_length:]
    result = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return extract_code(result)
# ==============================
# UI
# ==============================
# Two-column layout: the task description on the left, the target
# language picker on the right. `user_prompt` and `language` are read
# by the button handler further down.
st.title("💻 AI Code Generator (Fast & Accurate)")
col1, col2 = st.columns(2)
with col1:
    user_prompt = st.text_area("Describe your task", height=200)
with col2:
    language = st.selectbox(
        "Select Programming Language",
        ["Python", "JavaScript", "SQL", "Java", "C++", "HTML", "CSS"],
    )
# ==============================
# BUTTON
# ==============================
if st.button("Generate Code"):
    if not user_prompt.strip():
        st.warning("⚠️ Please enter a task")
    else:
        with st.spinner("⚡ Generating clean code..."):
            code = generate_code(user_prompt, language)
        st.success("✅ Generated Code")
        # st.code's syntax highlighter expects identifiers such as "cpp".
        # A plain .lower() turns "C++" into "c++", which it does not
        # recognise, so C++ output would be shown without highlighting.
        highlight_as = {"C++": "cpp"}.get(language, language.lower())
        st.code(code, language=highlight_as)