pradeep4321 commited on
Commit
f521540
Β·
verified Β·
1 Parent(s): ddf509e

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +39 -59
src/streamlit_app.py CHANGED
@@ -2,36 +2,24 @@ import streamlit as st
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
 
5
- # ==============================
6
- # πŸ” HIDE STREAMLIT MENU
7
- # ==============================
8
- st.markdown("""
9
- <style>
10
- #MainMenu {visibility: hidden;}
11
- header {visibility: hidden;}
12
- footer {visibility: hidden;}
13
- .stDeployButton {display:none;}
14
- </style>
15
- """, unsafe_allow_html=True)
16
-
17
  # ==============================
18
  # PAGE CONFIG
19
  # ==============================
20
  st.set_page_config(page_title="πŸ’» AI Code Generator", layout="wide")
21
 
22
  # ==============================
23
- # LOAD MODEL (FAST + STABLE)
24
  # ==============================
25
  @st.cache_resource
26
  def load_model():
27
- model_name = "microsoft/phi-2"
28
 
29
  tokenizer = AutoTokenizer.from_pretrained(model_name)
30
 
31
  model = AutoModelForCausalLM.from_pretrained(
32
  model_name,
33
- torch_dtype=torch.float32, # CPU friendly
34
- device_map=None
35
  )
36
 
37
  return tokenizer, model
@@ -39,17 +27,26 @@ def load_model():
39
  tokenizer, model = load_model()
40
 
41
  # ==============================
42
- # CODE GENERATION FUNCTION
43
  # ==============================
44
- def generate_code(prompt, language):
 
 
 
 
 
 
45
 
46
- full_prompt = f"""### Instruction:
47
- Write ONLY valid {language} code.
 
 
48
 
49
- Do not include explanations.
50
- Do not include special tokens.
 
51
 
52
- Task:
53
  {prompt}
54
 
55
  ### Response:
@@ -57,37 +54,27 @@ Task:
57
 
58
  inputs = tokenizer(full_prompt, return_tensors="pt")
59
 
60
- outputs = model.generate(
61
- **inputs,
62
- max_new_tokens=150,
63
- temperature=0.2,
64
- top_p=0.9,
65
- do_sample=False
66
- )
 
 
67
 
68
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
69
 
70
- # Extract response
71
- if "### Response:" in result:
72
- result = result.split("### Response:")[-1]
73
-
74
- # Remove unwanted tokens
75
- unwanted_tokens = [
76
- "<|endoftext|>",
77
- "<|file_separator|>",
78
- "<|assistant|>",
79
- "<|system|>"
80
- ]
81
-
82
- for token in unwanted_tokens:
83
- result = result.replace(token, "")
84
 
85
- return result.strip()
86
 
87
  # ==============================
88
  # UI
89
  # ==============================
90
- st.title("πŸ’» AI Code Generator (Fast & Clean)")
91
 
92
  col1, col2 = st.columns(2)
93
 
@@ -100,19 +87,12 @@ with col2:
100
  ["Python", "JavaScript", "SQL", "Java", "C++", "HTML", "CSS"]
101
  )
102
 
103
- # ==============================
104
- # GENERATE BUTTON
105
- # ==============================
106
- if st.button("πŸš€ Generate Code"):
107
  if not user_prompt.strip():
108
- st.warning("Please enter a task")
109
  else:
110
- with st.spinner("Generating code..."):
111
- try:
112
- code = generate_code(user_prompt, language)
113
-
114
- st.success("βœ… Generated Code")
115
- st.code(code, language=language.lower())
116
 
117
- except Exception as e:
118
- st.error(f"❌ Error: {str(e)}")
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  # ==============================
6
  # PAGE CONFIG
7
  # ==============================
8
  st.set_page_config(page_title="πŸ’» AI Code Generator", layout="wide")
9
 
10
  # ==============================
11
+ # LOAD MODEL (DeepSeek - BEST)
12
  # ==============================
13
  @st.cache_resource
14
  def load_model():
15
+ model_name = "deepseek-ai/deepseek-coder-1.3b-instruct"
16
 
17
  tokenizer = AutoTokenizer.from_pretrained(model_name)
18
 
19
  model = AutoModelForCausalLM.from_pretrained(
20
  model_name,
21
+ torch_dtype=torch.float32,
22
+ device_map="cpu"
23
  )
24
 
25
  return tokenizer, model
 
27
  tokenizer, model = load_model()
28
 
29
  # ==============================
30
+ # AUTO FIX CODE (IMPORTANT)
31
  # ==============================
32
+ def fix_incomplete_code(code):
33
+ # Fix missing brackets (basic handling)
34
+ if code.count("(") > code.count(")"):
35
+ code += ")"
36
+ if code.count("{") > code.count("}"):
37
+ code += "}"
38
+ return code.strip()
39
 
40
+ # ==============================
41
+ # GENERATE CODE
42
+ # ==============================
43
+ def generate_code(prompt, language):
44
 
45
+ full_prompt = f"""
46
+ ### Instruction:
47
+ Write a correct {language} function.
48
 
49
+ ### Task:
50
  {prompt}
51
 
52
  ### Response:
 
54
 
55
  inputs = tokenizer(full_prompt, return_tensors="pt")
56
 
57
+ with torch.no_grad():
58
+ outputs = model.generate(
59
+ **inputs,
60
+ max_new_tokens=200,
61
+ do_sample=False,
62
+ temperature=0.0,
63
+ eos_token_id=tokenizer.eos_token_id,
64
+ pad_token_id=tokenizer.eos_token_id
65
+ )
66
 
67
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
68
 
69
+ # Remove prompt
70
+ result = result.replace(full_prompt, "").strip()
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
+ return fix_incomplete_code(result)
73
 
74
  # ==============================
75
  # UI
76
  # ==============================
77
+ st.title("πŸ’» AI Code Generator (HF Optimized)")
78
 
79
  col1, col2 = st.columns(2)
80
 
 
87
  ["Python", "JavaScript", "SQL", "Java", "C++", "HTML", "CSS"]
88
  )
89
 
90
+ if st.button("Generate Code"):
 
 
 
91
  if not user_prompt.strip():
92
+ st.warning("⚠️ Enter a task")
93
  else:
94
+ with st.spinner("⚑ Generating code..."):
95
+ code = generate_code(user_prompt, language)
 
 
 
 
96
 
97
+ st.success("βœ… Generated Code")
98
+ st.code(code, language=language.lower())