File size: 2,682 Bytes
c7f53d4
7087b82
 
c7f53d4
7087b82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d73111a
7087b82
 
 
d144157
7087b82
45b1038
dd6b048
7087b82
 
d73111a
 
7087b82
 
 
 
 
 
 
d73111a
d144157
d73111a
 
 
 
 
 
d144157
 
 
 
d73111a
7087b82
 
 
d144157
d73111a
d144157
d73111a
d144157
d73111a
c7f53d4
 
45b1038
7087b82
d144157
 
 
d73111a
 
d144157
 
7087b82
 
 
d73111a
 
 
7087b82
 
 
 
45b1038
7087b82
 
 
 
 
 
 
 
 
 
 
 
5d9a347
d73111a
5d9a347
d144157
7087b82
d144157
7087b82
d73111a
d144157
5d9a347
d144157
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# ==============================
# PAGE CONFIG
# ==============================
# Issued before any other Streamlit command: Streamlit documents
# set_page_config as the first command of a page, and releases that enforce
# this raise StreamlitAPIException when another st.* call precedes it.
st.set_page_config(page_title="💻 AI Code Generator", layout="wide")

# ==============================
# 🔐 HIDE STREAMLIT MENU
# ==============================
# Injects raw CSS (hence unsafe_allow_html) that hides the main menu, header,
# footer and deploy button.
st.markdown("""
<style>
#MainMenu {visibility: hidden;}
header {visibility: hidden;}
footer {visibility: hidden;}
.stDeployButton {display:none;}
</style>
""", unsafe_allow_html=True)

# ==============================
# LOAD MODEL
# ==============================
@st.cache_resource
def load_model():
    """Load the tokenizer and causal-LM weights for ``google/codegemma-2b``.

    The model is requested in float32 with ``device_map="cpu"``. The function
    is wrapped in ``st.cache_resource``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "google/codegemma-2b"
    model_options = {"torch_dtype": torch.float32, "device_map": "cpu"}

    # Tuple elements are evaluated left to right: tokenizer first, then model.
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForCausalLM.from_pretrained(checkpoint, **model_options),
    )

# Module-level handles read by generate_code(); populated when the script runs.
tokenizer, model = load_model()

# ==============================
# CLEAN OUTPUT (IMPORTANT FIX)
# ==============================
def extract_code(text):
    """Return the code portion of a model response.

    If ``text`` contains a Markdown fence (three backticks), the content of
    the first fenced block is returned, without the language tag that may
    follow the opening fence (e.g. the ``python`` in a ``python``-tagged
    fence). An unclosed fence yields everything after it. Text without any
    fence is returned as-is.

    Args:
        text: Raw decoded model output.

    Returns:
        The extracted code with surrounding whitespace stripped.
    """
    fence = "```"
    if fence not in text:
        return text.strip()

    # Content between the first fence and the next one (or the end of the
    # text when the fence is never closed).
    block = text.split(fence, 2)[1]

    # The rest of the opening-fence line is the fence's info string. Drop it
    # only when it looks like a bare language tag (or is empty), so code that
    # starts on the same line as the fence, such as "x = 1", is preserved.
    first_line, newline, remainder = block.partition("\n")
    tag = first_line.strip()
    if newline and all(ch.isalnum() or ch in "+#-_." for ch in tag):
        block = remainder

    return block.strip()

# ==============================
# GENERATE CODE (SIMPLIFIED PROMPT)
# ==============================
def generate_code(prompt, language, max_new_tokens=120):
    """Generate code for a task description using the module-level model.

    Args:
        prompt: Task description; inserted verbatim into the instruction text.
        language: Name of the target programming language; inserted into the
            instruction text.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt. Defaults to 120.

    Returns:
        The model's continuation, passed through ``extract_code``.
    """

    full_prompt = f"""
Write a {language} function for the following task:

{prompt}

Only return code.
"""

    inputs = tokenizer(full_prompt, return_tensors="pt")

    with torch.no_grad():
        # Greedy decoding (do_sample=False). No `temperature` is passed: it
        # only applies when sampling, and a value of 0.0 is rejected or
        # warned about by transformers' generation-config validation.
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    # The returned sequence is the prompt tokens followed by the new tokens.
    # Slicing by prompt length removes the prompt reliably; string-replacing
    # the decoded prompt fails whenever decoding does not reproduce the
    # prompt text byte-for-byte.
    prompt_length = inputs["input_ids"].shape[1]
    completion_ids = outputs[0][prompt_length:]

    result = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

    return extract_code(result)

# ==============================
# UI
# ==============================
st.title("💻 AI Code Generator (Fast & Accurate)")

# Two side-by-side columns: task description on the left, language on the right.
col1, col2 = st.columns(2)

# Widgets are created by calling the methods on the column objects directly,
# which places them in the same columns as a `with` block would.
user_prompt = col1.text_area("Describe your task", height=200)

language = col2.selectbox(
    "Select Programming Language",
    ["Python", "JavaScript", "SQL", "Java", "C++", "HTML", "CSS"],
)

# ==============================
# BUTTON
# ==============================
# Display names whose lowercase form is not the highlighter's language id.
# NOTE(review): st.code's highlighter names C++ "cpp"; the remaining choices
# are assumed to be accepted in lowercase form — confirm against the
# supported-language list.
_HIGHLIGHT_LANGUAGE = {"C++": "cpp"}

if st.button("Generate Code"):
    if not user_prompt.strip():
        # Nothing but whitespace was entered; ask for input instead of
        # running the model on an empty task.
        st.warning("⚠️ Please enter a task")
    else:
        with st.spinner("⚡ Generating clean code..."):
            code = generate_code(user_prompt, language)

        st.success("✅ Generated Code")
        st.code(code, language=_HIGHLIGHT_LANGUAGE.get(language, language.lower()))