pradeep4321 commited on
Commit
7087b82
Β·
verified Β·
1 Parent(s): 5d1b5c1

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +89 -34
src/streamlit_app.py CHANGED
@@ -1,40 +1,95 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
 
 
14
  """
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
 
5
+ # ==============================
6
+ # πŸ” HIDE STREAMLIT MENU
7
+ # ==============================
8
+ st.markdown("""
9
+ <style>
10
+ #MainMenu {visibility: hidden;}
11
+ header {visibility: hidden;}
12
+ footer {visibility: hidden;}
13
+ .stDeployButton {display:none;}
14
+ </style>
15
+ """, unsafe_allow_html=True)
16
+
17
+ # ==============================
18
+ # PAGE CONFIG
19
+ # ==============================
20
+ st.set_page_config(page_title="πŸ’» AI Code Generator", layout="wide")
21
+
22
+ # ==============================
23
+ # LOAD MODEL
24
+ # ==============================
25
+ @st.cache_resource
26
+ def load_model():
27
+ model_name = "codellama/CodeLlama-7b-Instruct-hf"
28
+
29
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
30
+ model = AutoModelForCausalLM.from_pretrained(
31
+ model_name,
32
+ torch_dtype=torch.float16,
33
+ device_map="auto"
34
+ )
35
+
36
+ return tokenizer, model
37
+
38
+ tokenizer, model = load_model()
39
+
40
+ # ==============================
41
+ # CODE GENERATION FUNCTION
42
+ # ==============================
43
+ def generate_code(prompt, language):
44
+
45
+ full_prompt = f"""
46
+ You are an expert {language} developer.
47
+
48
+ Write clean, optimized, production-ready code.
49
 
50
+ Task:
51
+ {prompt}
 
52
 
53
+ Rules:
54
+ - Only return code
55
+ - No explanation
56
  """
57
 
58
+ inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
59
+
60
+ outputs = model.generate(
61
+ **inputs,
62
+ max_new_tokens=300,
63
+ temperature=0.2,
64
+ top_p=0.9
65
+ )
66
+
67
+ result = tokenizer.decode(outputs[0], skip_special_tokens=True)
68
+
69
+ return result.replace(full_prompt, "").strip()
70
+
71
+ # ==============================
72
+ # UI
73
+ # ==============================
74
+ st.title("πŸ’» AI Code Generator")
75
+
76
+ col1, col2 = st.columns(2)
77
+
78
+ with col1:
79
+ user_prompt = st.text_area("Describe your task", height=200)
80
+
81
+ with col2:
82
+ language = st.selectbox(
83
+ "Select Programming Language",
84
+ ["Python", "JavaScript", "SQL", "Java", "C++", "HTML", "CSS"]
85
+ )
86
+
87
+ if st.button("Generate Code"):
88
+ if not user_prompt.strip():
89
+ st.warning("Please enter a task")
90
+ else:
91
+ with st.spinner("Generating code..."):
92
+ code = generate_code(user_prompt, language)
93
+
94
+ st.success("βœ… Generated Code")
95
+ st.code(code, language=language.lower())