pradeep4321 commited on
Commit
ddf509e
Β·
verified Β·
1 Parent(s): d73111a

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +46 -37
src/streamlit_app.py CHANGED
@@ -20,18 +20,18 @@ footer {visibility: hidden;}
20
  st.set_page_config(page_title="πŸ’» AI Code Generator", layout="wide")
21
 
22
  # ==============================
23
- # LOAD MODEL
24
  # ==============================
25
  @st.cache_resource
26
  def load_model():
27
- model_name = "google/codegemma-2b"
28
 
29
  tokenizer = AutoTokenizer.from_pretrained(model_name)
30
 
31
  model = AutoModelForCausalLM.from_pretrained(
32
  model_name,
33
- torch_dtype=torch.float32,
34
- device_map="cpu"
35
  )
36
 
37
  return tokenizer, model
@@ -39,50 +39,55 @@ def load_model():
39
  tokenizer, model = load_model()
40
 
41
  # ==============================
42
- # CLEAN OUTPUT (IMPORTANT FIX)
43
- # ==============================
44
- def extract_code(text):
45
- # Try to extract code block if exists
46
- if "```" in text:
47
- parts = text.split("```")
48
- if len(parts) >= 2:
49
- return parts[1].strip()
50
-
51
- return text.strip()
52
-
53
- # ==============================
54
- # GENERATE CODE (SIMPLIFIED PROMPT)
55
  # ==============================
56
  def generate_code(prompt, language):
57
 
58
- full_prompt = f"""
59
- Write a {language} function for the following task:
 
 
 
60
 
 
61
  {prompt}
62
 
63
- Only return code.
64
  """
65
 
66
  inputs = tokenizer(full_prompt, return_tensors="pt")
67
 
68
- with torch.no_grad():
69
- outputs = model.generate(
70
- **inputs,
71
- max_new_tokens=120,
72
- do_sample=False,
73
- temperature=0.0
74
- )
75
 
76
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
77
 
78
- result = result.replace(full_prompt, "").strip()
 
 
 
 
 
 
 
 
 
 
79
 
80
- return extract_code(result)
 
 
 
81
 
82
  # ==============================
83
  # UI
84
  # ==============================
85
- st.title("πŸ’» AI Code Generator (Fast & Accurate)")
86
 
87
  col1, col2 = st.columns(2)
88
 
@@ -96,14 +101,18 @@ with col2:
96
  )
97
 
98
  # ==============================
99
- # BUTTON
100
  # ==============================
101
- if st.button("Generate Code"):
102
  if not user_prompt.strip():
103
- st.warning("⚠️ Please enter a task")
104
  else:
105
- with st.spinner("⚑ Generating clean code..."):
106
- code = generate_code(user_prompt, language)
 
 
 
 
107
 
108
- st.success("βœ… Generated Code")
109
- st.code(code, language=language.lower())
 
20
  st.set_page_config(page_title="πŸ’» AI Code Generator", layout="wide")
21
 
22
  # ==============================
23
+ # LOAD MODEL (FAST + STABLE)
24
  # ==============================
25
  @st.cache_resource
26
  def load_model():
27
+ model_name = "microsoft/phi-2"
28
 
29
  tokenizer = AutoTokenizer.from_pretrained(model_name)
30
 
31
  model = AutoModelForCausalLM.from_pretrained(
32
  model_name,
33
+ torch_dtype=torch.float32, # CPU friendly
34
+ device_map=None
35
  )
36
 
37
  return tokenizer, model
 
39
  tokenizer, model = load_model()
40
 
41
  # ==============================
42
+ # CODE GENERATION FUNCTION
 
 
 
 
 
 
 
 
 
 
 
 
43
  # ==============================
44
  def generate_code(prompt, language):
45
 
46
+ full_prompt = f"""### Instruction:
47
+ Write ONLY valid {language} code.
48
+
49
+ Do not include explanations.
50
+ Do not include special tokens.
51
 
52
+ Task:
53
  {prompt}
54
 
55
+ ### Response:
56
  """
57
 
58
  inputs = tokenizer(full_prompt, return_tensors="pt")
59
 
60
+ outputs = model.generate(
61
+ **inputs,
62
+ max_new_tokens=150,
63
+ temperature=0.2,
64
+ top_p=0.9,
65
+ do_sample=False
66
+ )
67
 
68
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
69
 
70
+ # Extract response
71
+ if "### Response:" in result:
72
+ result = result.split("### Response:")[-1]
73
+
74
+ # Remove unwanted tokens
75
+ unwanted_tokens = [
76
+ "<|endoftext|>",
77
+ "<|file_separator|>",
78
+ "<|assistant|>",
79
+ "<|system|>"
80
+ ]
81
 
82
+ for token in unwanted_tokens:
83
+ result = result.replace(token, "")
84
+
85
+ return result.strip()
86
 
87
  # ==============================
88
  # UI
89
  # ==============================
90
+ st.title("πŸ’» AI Code Generator (Fast & Clean)")
91
 
92
  col1, col2 = st.columns(2)
93
 
 
101
  )
102
 
103
  # ==============================
104
+ # GENERATE BUTTON
105
  # ==============================
106
+ if st.button("πŸš€ Generate Code"):
107
  if not user_prompt.strip():
108
+ st.warning("Please enter a task")
109
  else:
110
+ with st.spinner("Generating code..."):
111
+ try:
112
+ code = generate_code(user_prompt, language)
113
+
114
+ st.success("βœ… Generated Code")
115
+ st.code(code, language=language.lower())
116
 
117
+ except Exception as e:
118
+ st.error(f"❌ Error: {str(e)}")