pradeep4321 committed on
Commit
812346b
·
verified ·
1 Parent(s): f521540

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +60 -26
src/streamlit_app.py CHANGED
@@ -8,18 +8,18 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
8
  st.set_page_config(page_title="💻 AI Code Generator", layout="wide")
9
 
10
  # ==============================
11
- # LOAD MODEL (DeepSeek - BEST)
12
  # ==============================
13
  @st.cache_resource
14
  def load_model():
15
- model_name = "deepseek-ai/deepseek-coder-1.3b-instruct"
16
 
17
  tokenizer = AutoTokenizer.from_pretrained(model_name)
18
 
19
  model = AutoModelForCausalLM.from_pretrained(
20
  model_name,
21
  torch_dtype=torch.float32,
22
- device_map="cpu"
23
  )
24
 
25
  return tokenizer, model
@@ -27,14 +27,24 @@ def load_model():
27
  tokenizer, model = load_model()
28
 
29
  # ==============================
30
- # AUTO FIX CODE (IMPORTANT)
31
  # ==============================
32
- def fix_incomplete_code(code):
33
- # Fix missing brackets (basic handling)
34
- if code.count("(") > code.count(")"):
35
- code += ")"
36
- if code.count("{") > code.count("}"):
37
- code += "}"
 
 
 
 
 
 
 
 
 
 
38
  return code.strip()
39
 
40
  # ==============================
@@ -43,38 +53,53 @@ def fix_incomplete_code(code):
43
  def generate_code(prompt, language):
44
 
45
  full_prompt = f"""
46
- ### Instruction:
47
- Write a correct {language} function.
48
 
49
- ### Task:
 
 
 
 
 
 
 
 
50
  {prompt}
51
 
52
- ### Response:
53
  """
54
 
55
- inputs = tokenizer(full_prompt, return_tensors="pt")
56
 
57
  with torch.no_grad():
58
  outputs = model.generate(
59
  **inputs,
60
- max_new_tokens=200,
61
- do_sample=False,
62
- temperature=0.0,
 
 
63
  eos_token_id=tokenizer.eos_token_id,
64
  pad_token_id=tokenizer.eos_token_id
65
  )
66
 
67
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
68
 
69
- # Remove prompt
70
- result = result.replace(full_prompt, "").strip()
 
 
71
 
72
- return fix_incomplete_code(result)
 
 
 
 
73
 
74
  # ==============================
75
  # UI
76
  # ==============================
77
- st.title("💻 AI Code Generator (HF Optimized)")
78
 
79
  col1, col2 = st.columns(2)
80
 
@@ -87,12 +112,21 @@ with col2:
87
  ["Python", "JavaScript", "SQL", "Java", "C++", "HTML", "CSS"]
88
  )
89
 
 
 
 
90
  if st.button("Generate Code"):
91
  if not user_prompt.strip():
92
- st.warning("⚠️ Enter a task")
93
  else:
94
- with st.spinner("⚡ Generating code..."):
95
  code = generate_code(user_prompt, language)
96
 
97
- st.success("✅ Generated Code")
98
- st.code(code, language=language.lower())
 
 
 
 
 
 
 
8
  st.set_page_config(page_title="💻 AI Code Generator", layout="wide")
9
 
10
  # ==============================
11
+ # LOAD MODEL
12
  # ==============================
13
  @st.cache_resource
14
  def load_model():
15
+ model_name = "deepseek-ai/deepseek-coder-6.7b-instruct" # 🔥 BEST MODEL
16
 
17
  tokenizer = AutoTokenizer.from_pretrained(model_name)
18
 
19
  model = AutoModelForCausalLM.from_pretrained(
20
  model_name,
21
  torch_dtype=torch.float32,
22
+ device_map="auto" # auto uses GPU if available
23
  )
24
 
25
  return tokenizer, model
 
27
  tokenizer, model = load_model()
28
 
29
  # ==============================
30
+ # CLEAN OUTPUT
31
  # ==============================
32
def clean_code(code):
    """Strip chat boilerplate and Markdown fences from generated text.

    Args:
        code: Raw text produced by the model.

    Returns:
        The text with the known label phrases removed, Markdown code-fence
        markers removed (including the language tag of an opening fence,
        for any language, not only Python), and surrounding whitespace
        trimmed.
    """
    import re

    labels = (
        "Explanation:",
        "Here is the code:",
        "Output:",
        "Answer:",
    )

    text = code.strip()
    for label in labels:
        text = text.replace(label, "")

    # An opening fence looks like "```javascript" at the end of a line.
    # Remove the marker together with its language tag so the tag does not
    # survive as a bogus first line of code. The look-ahead restricts this
    # to fences that end their line, so inline "```print(1)```" keeps its
    # content.
    text = re.sub(r"```[\w+#.-]*[ \t]*(?=\r?\n|$)", "", text)

    # Any remaining bare fence markers (e.g. inline ones) are dropped as-is.
    text = text.replace("```", "")

    return text.strip()
49
 
50
  # ==============================
 
53
  def generate_code(prompt, language):
54
 
55
  full_prompt = f"""
56
+ You are an expert {language} developer.
 
57
 
58
+ Generate clean, correct, and complete code.
59
+
60
+ Rules:
61
+ - Only return code
62
+ - No explanation
63
+ - Proper syntax
64
+ - Complete working solution
65
+
66
+ Task:
67
  {prompt}
68
 
69
+ Code:
70
  """
71
 
72
+ inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True)
73
 
74
  with torch.no_grad():
75
  outputs = model.generate(
76
  **inputs,
77
+ max_new_tokens=350,
78
+ do_sample=True,
79
+ temperature=0.2,
80
+ top_p=0.9,
81
+ repetition_penalty=1.1,
82
  eos_token_id=tokenizer.eos_token_id,
83
  pad_token_id=tokenizer.eos_token_id
84
  )
85
 
86
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
87
 
88
+ if "Code:" in result:
89
+ result = result.split("Code:")[-1]
90
+
91
+ return clean_code(result)
92
 
93
+ # ==============================
94
+ # SESSION STATE (CHAT HISTORY)
95
+ # ==============================
96
+ if "history" not in st.session_state:
97
+ st.session_state.history = []
98
 
99
  # ==============================
100
  # UI
101
  # ==============================
102
+ st.title("💻 AI Code Generator (Advanced)")
103
 
104
  col1, col2 = st.columns(2)
105
 
 
112
  ["Python", "JavaScript", "SQL", "Java", "C++", "HTML", "CSS"]
113
  )
114
 
115
+ # ==============================
116
+ # GENERATE BUTTON
117
+ # ==============================
118
  if st.button("Generate Code"):
119
  if not user_prompt.strip():
120
+ st.warning("⚠️ Please enter a task")
121
  else:
122
+ with st.spinner("⚡ Generating high-quality code..."):
123
  code = generate_code(user_prompt, language)
124
 
125
+ st.session_state.history.append((user_prompt, code))
126
+
127
+ # ==============================
128
+ # DISPLAY HISTORY
129
+ # ==============================
130
+ for q, c in reversed(st.session_state.history):
131
+ st.markdown(f"### 🧑 Task:\n{q}")
132
+ st.code(c, language=language.lower())