pradeep4321 commited on
Commit
5d9a347
Β·
verified Β·
1 Parent(s): dd6b048

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +40 -19
src/streamlit_app.py CHANGED
@@ -1,6 +1,7 @@
1
  import streamlit as st
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
4
 
5
  # ==============================
6
  # πŸ” HIDE STREAMLIT MENU
@@ -20,18 +21,31 @@ footer {visibility: hidden;}
20
  st.set_page_config(page_title="πŸ’» AI Code Generator", layout="wide")
21
 
22
  # ==============================
23
- # LOAD MODEL (OPTIMIZED)
 
 
 
 
 
 
 
 
 
24
  # ==============================
25
  @st.cache_resource
26
  def load_model():
27
  model_name = "google/codegemma-2b"
28
 
29
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
30
 
31
  model = AutoModelForCausalLM.from_pretrained(
32
  model_name,
33
- torch_dtype=torch.float32, # CPU safe
34
- device_map="cpu" # force CPU (faster on Spaces)
 
35
  )
36
 
37
  return tokenizer, model
@@ -56,25 +70,25 @@ Rules:
56
  - No explanation
57
  """
58
 
59
- inputs = tokenizer(full_prompt, return_tensors="pt")
60
 
61
- with torch.no_grad():
62
- outputs = model.generate(
63
- **inputs,
64
- max_new_tokens=150, # reduced for speed
65
- temperature=0.1,
66
- top_p=0.85,
67
- do_sample=True
68
- )
69
 
70
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
71
 
 
72
  return result.replace(full_prompt, "").strip()
73
 
74
  # ==============================
75
  # UI
76
  # ==============================
77
- st.title("πŸ’» AI Code Generator")
78
 
79
  col1, col2 = st.columns(2)
80
 
@@ -87,12 +101,19 @@ with col2:
87
  ["Python", "JavaScript", "SQL", "Java", "C++", "HTML", "CSS"]
88
  )
89
 
90
- if st.button("Generate Code"):
 
 
 
91
  if not user_prompt.strip():
92
  st.warning("Please enter a task")
93
  else:
94
- with st.spinner("⚑ Generating fast code..."):
95
- code = generate_code(user_prompt, language)
 
 
 
 
96
 
97
- st.success("βœ… Generated Code")
98
- st.code(code, language=language.lower())
 
1
  import streamlit as st
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ import os
5
 
6
  # ==============================
7
  # πŸ” HIDE STREAMLIT MENU
 
21
  st.set_page_config(page_title="πŸ’» AI Code Generator", layout="wide")
22
 
23
  # ==============================
24
+ # HF TOKEN (SECURE)
25
+ # ==============================
26
+ HF_TOKEN = os.environ.get("HF_TOKEN")
27
+
28
+ if not HF_TOKEN:
29
+ st.error("❌ HF_TOKEN not found. Add it in Hugging Face Secrets.")
30
+ st.stop()
31
+
32
+ # ==============================
33
+ # LOAD MODEL (CACHED)
34
  # ==============================
35
  @st.cache_resource
36
  def load_model():
37
  model_name = "google/codegemma-2b"
38
 
39
+ tokenizer = AutoTokenizer.from_pretrained(
40
+ model_name,
41
+ token=HF_TOKEN
42
+ )
43
 
44
  model = AutoModelForCausalLM.from_pretrained(
45
  model_name,
46
+ token=HF_TOKEN,
47
+ torch_dtype=torch.float16,
48
+ device_map="auto"
49
  )
50
 
51
  return tokenizer, model
 
70
  - No explanation
71
  """
72
 
73
+ inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
74
 
75
+ outputs = model.generate(
76
+ **inputs,
77
+ max_new_tokens=300,
78
+ temperature=0.2,
79
+ top_p=0.9,
80
+ do_sample=True
81
+ )
 
82
 
83
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
84
 
85
+ # Remove prompt from output
86
  return result.replace(full_prompt, "").strip()
87
 
88
  # ==============================
89
  # UI
90
  # ==============================
91
+ st.title("πŸ’» AI Code Generator (CodeGemma)")
92
 
93
  col1, col2 = st.columns(2)
94
 
 
101
  ["Python", "JavaScript", "SQL", "Java", "C++", "HTML", "CSS"]
102
  )
103
 
104
+ # ==============================
105
+ # GENERATE BUTTON
106
+ # ==============================
107
+ if st.button("πŸš€ Generate Code"):
108
  if not user_prompt.strip():
109
  st.warning("Please enter a task")
110
  else:
111
+ with st.spinner("Generating code..."):
112
+ try:
113
+ code = generate_code(user_prompt, language)
114
+
115
+ st.success("βœ… Generated Code")
116
+ st.code(code, language=language.lower())
117
 
118
+ except Exception as e:
119
+ st.error(f"❌ Error: {str(e)}")