pradeep4321 commited on
Commit
d73111a
·
verified ·
1 Parent(s): d144157

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +21 -35
src/streamlit_app.py CHANGED
@@ -20,7 +20,7 @@ footer {visibility: hidden;}
20
  st.set_page_config(page_title="💻 AI Code Generator", layout="wide")
21
 
22
  # ==============================
23
- # LOAD MODEL (FAST + CPU SAFE)
24
  # ==============================
25
  @st.cache_resource
26
  def load_model():
@@ -30,8 +30,8 @@ def load_model():
30
 
31
  model = AutoModelForCausalLM.from_pretrained(
32
  model_name,
33
- torch_dtype=torch.float32, # CPU safe
34
- device_map="cpu" # force CPU for Spaces
35
  )
36
 
37
  return tokenizer, model
@@ -39,42 +39,28 @@ def load_model():
39
  tokenizer, model = load_model()
40
 
41
  # ==============================
42
- # CLEAN OUTPUT FUNCTION
43
  # ==============================
44
- def clean_output(text, prompt):
45
- text = text.replace(prompt, "").strip()
46
-
47
- # Remove unwanted prefixes if model adds them
48
- unwanted_phrases = ["```", "code:", "Code:"]
49
- for phrase in unwanted_phrases:
50
- text = text.replace(phrase, "")
51
 
52
  return text.strip()
53
 
54
  # ==============================
55
- # CODE GENERATION FUNCTION (FIXED)
56
  # ==============================
57
  def generate_code(prompt, language):
58
 
59
  full_prompt = f"""
60
- You are a highly accurate {language} code generator.
61
-
62
- Example:
63
- Task: add two numbers
64
- Code:
65
- function add(a, b) {{
66
- return a + b;
67
- }}
68
-
69
- Now solve:
70
 
71
- Task: {prompt}
72
 
73
- Instructions:
74
- - Generate correct and complete {language} code
75
- - Do exactly what is asked
76
- - Do NOT change the logic
77
- - Return ONLY code
78
  """
79
 
80
  inputs = tokenizer(full_prompt, return_tensors="pt")
@@ -82,14 +68,16 @@ Instructions:
82
  with torch.no_grad():
83
  outputs = model.generate(
84
  **inputs,
85
- max_new_tokens=200, # prevent truncation
86
- do_sample=False, # deterministic (IMPORTANT)
87
  temperature=0.0
88
  )
89
 
90
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
91
 
92
- return clean_output(result, full_prompt)
 
 
93
 
94
  # ==============================
95
  # UI
@@ -108,16 +96,14 @@ with col2:
108
  )
109
 
110
  # ==============================
111
- # BUTTON ACTION
112
  # ==============================
113
  if st.button("Generate Code"):
114
  if not user_prompt.strip():
115
  st.warning("⚠️ Please enter a task")
116
  else:
117
- with st.spinner("⚡ Generating high-quality code..."):
118
  code = generate_code(user_prompt, language)
119
 
120
  st.success("✅ Generated Code")
121
-
122
- # Display properly formatted code
123
  st.code(code, language=language.lower())
 
20
  st.set_page_config(page_title="💻 AI Code Generator", layout="wide")
21
 
22
  # ==============================
23
+ # LOAD MODEL
24
  # ==============================
25
  @st.cache_resource
26
  def load_model():
 
30
 
31
  model = AutoModelForCausalLM.from_pretrained(
32
  model_name,
33
+ torch_dtype=torch.float32,
34
+ device_map="cpu"
35
  )
36
 
37
  return tokenizer, model
 
39
  tokenizer, model = load_model()
40
 
41
  # ==============================
42
+ # CLEAN OUTPUT (IMPORTANT FIX)
43
  # ==============================
44
+ def extract_code(text):
45
+ # Try to extract code block if exists
46
+ if "```" in text:
47
+ parts = text.split("```")
48
+ if len(parts) >= 2:
49
+ return parts[1].strip()
 
50
 
51
  return text.strip()
52
 
53
  # ==============================
54
+ # GENERATE CODE (SIMPLIFIED PROMPT)
55
  # ==============================
56
  def generate_code(prompt, language):
57
 
58
  full_prompt = f"""
59
+ Write a {language} function for the following task:
 
 
 
 
 
 
 
 
 
60
 
61
+ {prompt}
62
 
63
+ Only return code.
 
 
 
 
64
  """
65
 
66
  inputs = tokenizer(full_prompt, return_tensors="pt")
 
68
  with torch.no_grad():
69
  outputs = model.generate(
70
  **inputs,
71
+ max_new_tokens=120,
72
+ do_sample=False,
73
  temperature=0.0
74
  )
75
 
76
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
77
 
78
+ result = result.replace(full_prompt, "").strip()
79
+
80
+ return extract_code(result)
81
 
82
  # ==============================
83
  # UI
 
96
  )
97
 
98
  # ==============================
99
+ # BUTTON
100
  # ==============================
101
  if st.button("Generate Code"):
102
  if not user_prompt.strip():
103
  st.warning("⚠️ Please enter a task")
104
  else:
105
+ with st.spinner("⚡ Generating clean code..."):
106
  code = generate_code(user_prompt, language)
107
 
108
  st.success("✅ Generated Code")
 
 
109
  st.code(code, language=language.lower())