waleed-12 commited on
Commit
72419b9
·
verified ·
1 Parent(s): 14b6d9b

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +38 -6
src/streamlit_app.py CHANGED
@@ -1,23 +1,55 @@
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
 
4
 
5
  # Load model and tokenizer
6
  MODEL_NAME = "Johannes/code-generation-model-fine-tuned-to-produce-good-code-snippets"
7
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
8
  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
9
 
10
- st.title("Code Generation Model")
11
 
12
- prompt = st.text_area("Enter a code prompt:")
13
 
14
- max_length = st.slider("Maximum generated tokens", min_value=50, max_value=500, value=150)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  if st.button("Generate Code"):
17
  if prompt.strip() != "":
18
  inputs = tokenizer(prompt, return_tensors="pt")
19
- outputs = model.generate(**inputs, max_length=max_length, do_sample=True, temperature=0.7)
 
 
 
 
 
20
  generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
21
- st.code(generated_code, language="python")
 
 
 
 
 
22
  else:
23
- st.warning("Please enter a prompt.")
 
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
+ import re
5
 
6
  # Load model and tokenizer
7
  MODEL_NAME = "Johannes/code-generation-model-fine-tuned-to-produce-good-code-snippets"
8
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
10
 
11
+ st.title("Pseudo-code to Code")
12
 
13
+ prompt = st.text_area("Enter a code:")
14
 
15
+
16
+ def remove_comments_and_headers(code_text):
17
+ """
18
+ Removes:
19
+ - Python-style comments starting with #
20
+ - C/C++ style comments starting with //
21
+ - License or header blocks
22
+ """
23
+ lines = code_text.split("\n")
24
+ code_lines = []
25
+
26
+ for line in lines:
27
+ stripped = line.strip()
28
+ # Skip empty lines or comment lines
29
+ if stripped == "":
30
+ continue
31
+ if stripped.startswith("#") or stripped.startswith("//"):
32
+ continue
33
+ code_lines.append(line)
34
+
35
+ # Join remaining lines
36
+ return "\n".join(code_lines)
37
 
38
  if st.button("Generate Code"):
39
  if prompt.strip() != "":
40
  inputs = tokenizer(prompt, return_tensors="pt")
41
+ outputs = model.generate(
42
+ **inputs,
43
+ max_length=max_length,
44
+ do_sample=True,
45
+ temperature=0.7
46
+ )
47
  generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
48
+
49
+ # Clean code by removing comments / license headers
50
+ cleaned_code = remove_comments_and_headers(generated_code)
51
+
52
+ st.subheader("Generated Code:")
53
+ st.code(cleaned_code, language="python")
54
  else:
55
+ st.warning("Enter prompt.")