pradeep4321 commited on
Commit
dd6b048
·
verified ·
1 Parent(s): 7087b82

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +17 -14
src/streamlit_app.py CHANGED
@@ -20,17 +20,18 @@ footer {visibility: hidden;}
20
  st.set_page_config(page_title="💻 AI Code Generator", layout="wide")
21
 
22
  # ==============================
23
- # LOAD MODEL
24
  # ==============================
25
  @st.cache_resource
26
  def load_model():
27
- model_name = "codellama/CodeLlama-7b-Instruct-hf"
28
 
29
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
30
  model = AutoModelForCausalLM.from_pretrained(
31
  model_name,
32
- torch_dtype=torch.float16,
33
- device_map="auto"
34
  )
35
 
36
  return tokenizer, model
@@ -55,14 +56,16 @@ Rules:
55
  - No explanation
56
  """
57
 
58
- inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
59
 
60
- outputs = model.generate(
61
- **inputs,
62
- max_new_tokens=300,
63
- temperature=0.2,
64
- top_p=0.9
65
- )
 
 
66
 
67
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
68
 
@@ -88,8 +91,8 @@ if st.button("Generate Code"):
88
  if not user_prompt.strip():
89
  st.warning("Please enter a task")
90
  else:
91
- with st.spinner("Generating code..."):
92
  code = generate_code(user_prompt, language)
93
 
94
- st.success("✅ Generated Code")
95
- st.code(code, language=language.lower())
 
20
  st.set_page_config(page_title="💻 AI Code Generator", layout="wide")
21
 
22
  # ==============================
23
+ # LOAD MODEL (OPTIMIZED)
24
  # ==============================
25
  @st.cache_resource
26
  def load_model():
27
+ model_name = "google/codegemma-2b"
28
 
29
  tokenizer = AutoTokenizer.from_pretrained(model_name)
30
+
31
  model = AutoModelForCausalLM.from_pretrained(
32
  model_name,
33
+ torch_dtype=torch.float32, # CPU safe
34
+ device_map="cpu" # force CPU (faster on Spaces)
35
  )
36
 
37
  return tokenizer, model
 
56
  - No explanation
57
  """
58
 
59
+ inputs = tokenizer(full_prompt, return_tensors="pt")
60
 
61
+ with torch.no_grad():
62
+ outputs = model.generate(
63
+ **inputs,
64
+ max_new_tokens=150, # reduced for speed
65
+ temperature=0.1,
66
+ top_p=0.85,
67
+ do_sample=True
68
+ )
69
 
70
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
71
 
 
91
  if not user_prompt.strip():
92
  st.warning("Please enter a task")
93
  else:
94
+ with st.spinner("Generating fast code..."):
95
  code = generate_code(user_prompt, language)
96
 
97
+ st.success("✅ Generated Code")
98
+ st.code(code, language=language.lower())