pradeep4321 commited on
Commit
d144157
Β·
verified Β·
1 Parent(s): 45b1038

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +51 -32
src/streamlit_app.py CHANGED
@@ -20,18 +20,18 @@ footer {visibility: hidden;}
20
  st.set_page_config(page_title="πŸ’» AI Code Generator", layout="wide")
21
 
22
  # ==============================
23
- # LOAD MODEL (FAST)
24
  # ==============================
25
  @st.cache_resource
26
  def load_model():
27
- model_name = "microsoft/phi-2"
28
 
29
  tokenizer = AutoTokenizer.from_pretrained(model_name)
30
 
31
  model = AutoModelForCausalLM.from_pretrained(
32
  model_name,
33
- torch_dtype=torch.float32, # CPU friendly
34
- device_map=None # avoid GPU issues
35
  )
36
 
37
  return tokenizer, model
@@ -39,36 +39,57 @@ def load_model():
39
  tokenizer, model = load_model()
40
 
41
  # ==============================
42
- # CODE GENERATION FUNCTION
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  # ==============================
44
  def generate_code(prompt, language):
45
 
46
- full_prompt = f"""### Instruction:
47
- Write a {language} program for the following task.
48
 
49
- Task:
50
- {prompt}
 
 
 
 
51
 
52
- ### Response:
 
 
 
 
 
 
 
 
53
  """
54
 
55
  inputs = tokenizer(full_prompt, return_tensors="pt")
56
 
57
- outputs = model.generate(
58
- **inputs,
59
- max_new_tokens=150,
60
- temperature=0.2,
61
- top_p=0.9,
62
- do_sample=False
63
- )
64
 
65
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
66
 
67
- # Extract only response part
68
- if "### Response:" in result:
69
- result = result.split("### Response:")[-1]
70
-
71
- return result.strip()
72
 
73
  # ==============================
74
  # UI
@@ -87,18 +108,16 @@ with col2:
87
  )
88
 
89
  # ==============================
90
- # GENERATE BUTTON
91
  # ==============================
92
- if st.button("πŸš€ Generate Code"):
93
  if not user_prompt.strip():
94
- st.warning("Please enter a task")
95
  else:
96
- with st.spinner("Generating code..."):
97
- try:
98
- code = generate_code(user_prompt, language)
99
 
100
- st.success("βœ… Generated Code")
101
- st.code(code, language=language.lower())
102
 
103
- except Exception as e:
104
- st.error(f"❌ Error: {str(e)}")
 
20
  st.set_page_config(page_title="πŸ’» AI Code Generator", layout="wide")
21
 
22
  # ==============================
23
+ # LOAD MODEL (FAST + CPU SAFE)
24
  # ==============================
25
  @st.cache_resource
26
  def load_model():
27
+ model_name = "google/codegemma-2b"
28
 
29
  tokenizer = AutoTokenizer.from_pretrained(model_name)
30
 
31
  model = AutoModelForCausalLM.from_pretrained(
32
  model_name,
33
+ torch_dtype=torch.float32, # CPU safe
34
+ device_map="cpu" # force CPU for Spaces
35
  )
36
 
37
  return tokenizer, model
 
39
  tokenizer, model = load_model()
40
 
41
  # ==============================
42
+ # CLEAN OUTPUT FUNCTION
43
+ # ==============================
44
+ def clean_output(text, prompt):
45
+ text = text.replace(prompt, "").strip()
46
+
47
+ # Remove unwanted prefixes if model adds them
48
+ unwanted_phrases = ["```", "code:", "Code:"]
49
+ for phrase in unwanted_phrases:
50
+ text = text.replace(phrase, "")
51
+
52
+ return text.strip()
53
+
54
+ # ==============================
55
+ # CODE GENERATION FUNCTION (FIXED)
56
  # ==============================
57
  def generate_code(prompt, language):
58
 
59
+ full_prompt = f"""
60
+ You are a highly accurate {language} code generator.
61
 
62
+ Example:
63
+ Task: add two numbers
64
+ Code:
65
+ function add(a, b) {{
66
+ return a + b;
67
+ }}
68
 
69
+ Now solve:
70
+
71
+ Task: {prompt}
72
+
73
+ Instructions:
74
+ - Generate correct and complete {language} code
75
+ - Do exactly what is asked
76
+ - Do NOT change the logic
77
+ - Return ONLY code
78
  """
79
 
80
  inputs = tokenizer(full_prompt, return_tensors="pt")
81
 
82
+ with torch.no_grad():
83
+ outputs = model.generate(
84
+ **inputs,
85
+ max_new_tokens=200, # prevent truncation
86
+ do_sample=False, # deterministic (IMPORTANT)
87
+ temperature=0.0
88
+ )
89
 
90
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
91
 
92
+ return clean_output(result, full_prompt)
 
 
 
 
93
 
94
  # ==============================
95
  # UI
 
108
  )
109
 
110
  # ==============================
111
+ # BUTTON ACTION
112
  # ==============================
113
+ if st.button("Generate Code"):
114
  if not user_prompt.strip():
115
+ st.warning("⚠️ Please enter a task")
116
  else:
117
+ with st.spinner("⚑ Generating high-quality code..."):
118
+ code = generate_code(user_prompt, language)
 
119
 
120
+ st.success("βœ… Generated Code")
 
121
 
122
+ # Display properly formatted code
123
+ st.code(code, language=language.lower())