shaheerawan3 committed on
Commit
615a9ed
·
verified ·
1 Parent(s): b5faa78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -107
app.py CHANGED
@@ -2,114 +2,128 @@ import streamlit as st
2
  from langchain_community.embeddings import HuggingFaceEmbeddings
3
  from langchain_community.vectorstores import Chroma
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
- from langchain_community.llms import CTransformers
6
- from langchain.chains import RetrievalQA
7
- from langchain.prompts import PromptTemplate
8
- import os
9
- from pathlib import Path
10
  import logging
 
11
 
12
- class LocalWebDevRAG:
13
    def __init__(self):
        """Wire up the RAG pipeline: logging, embeddings, local LLM, vector store."""
        # Logger is created first so later steps can use self.logger;
        # initialize_vector_store consumes both self.embeddings and self.llm,
        # so it must run last.
        self.initialize_logging()
        self.setup_embeddings()
        self.setup_llm()
        self.initialize_vector_store()
18
 
19
    def initialize_logging(self):
        """Configure root logging at INFO and keep a module-named logger on self."""
        # NOTE: basicConfig only takes effect the first time it runs in a
        # process; repeated instantiation is a harmless no-op afterwards.
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def setup_embeddings(self):
24
  self.embeddings = HuggingFaceEmbeddings(
25
  model_name="all-MiniLM-L6-v2",
26
  model_kwargs={'device': 'cpu'}
27
  )
28
-
29
    def setup_llm(self):
        """Load the local CodeLlama GGML model and build the QA prompt template."""
        # Using CodeLlama local model
        llm_config = {
            'model': 'codellama-7b-instruct.ggmlv3.Q4_K_M.bin',
            'model_type': 'llama',
            'max_new_tokens': 2048,
            'temperature': 0.7,
            # NOTE(review): max_new_tokens equals context_length — generation
            # plus prompt cannot both fit in the window; confirm intended.
            'context_length': 2048,
        }

        self.llm = CTransformers(**llm_config)

        # Prompt used by the "stuff" RetrievalQA chain; {context} is filled
        # with retrieved documents, {question} with the user request.
        self.qa_prompt = PromptTemplate(
            template="""You are an expert web developer. Based on the context and request,
        generate production-ready code.

        Context: {context}
        Question: {question}

        Provide a detailed solution with explanations.""",
            input_variables=["context", "question"]
        )
51
 
52
  def initialize_vector_store(self):
53
- try:
54
- # Create or load vector store
55
- if not Path("chroma_db").exists():
56
- self.create_new_vector_store()
57
- else:
58
- self.vector_store = Chroma(
59
- persist_directory="chroma_db",
60
- embedding_function=self.embeddings
61
- )
62
- self.logger.info("Loaded existing vector store")
63
-
64
- self.qa_chain = RetrievalQA.from_chain_type(
65
- llm=self.llm,
66
- chain_type="stuff",
67
- retriever=self.vector_store.as_retriever(),
68
- chain_type_kwargs={"prompt": self.qa_prompt}
69
- )
70
-
71
- except Exception as e:
72
- self.logger.error(f"Vector store initialization failed: {e}")
73
- raise
74
 
75
- def create_new_vector_store(self):
76
- # Example code snippets and documentation
77
  documents = [
78
- "React component best practices...",
79
- "API security implementations...",
80
- "Database schema designs...",
81
- # Add more code examples and documentation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  ]
83
 
84
  text_splitter = RecursiveCharacterTextSplitter(
85
- chunk_size=1000,
86
- chunk_overlap=200
87
  )
88
  texts = text_splitter.split_text('\n\n'.join(documents))
89
 
90
- self.vector_store = Chroma.from_texts(
91
  texts,
92
  self.embeddings,
93
  persist_directory="chroma_db"
94
  )
95
- self.logger.info("Created new vector store")
96
 
97
- def generate_code(self, description, tech_stack, requirements):
98
  try:
99
- prompt = f"""
100
- Create a web application with:
101
  Description: {description}
102
- Tech Stack: {tech_stack}
103
- Requirements: {requirements}
104
 
105
- Provide:
106
- 1. Frontend components
107
- 2. Backend API
108
- 3. Database schema
109
- 4. Setup instructions
110
  """
111
 
112
- response = self.qa_chain.run(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
113
  return self.process_response(response)
114
 
115
  except Exception as e:
@@ -117,25 +131,44 @@ class LocalWebDevRAG:
117
  raise
118
 
119
  def process_response(self, response):
120
- # Basic response processing
121
- return {
122
- "frontend": response.split("Frontend:")[1].split("Backend:")[0] if "Frontend:" in response else "",
123
- "backend": response.split("Backend:")[1].split("Database:")[0] if "Backend:" in response else "",
124
- "database": response.split("Database:")[1].split("Setup:")[0] if "Database:" in response else "",
125
- "setup": response.split("Setup:")[1] if "Setup:" in response else response
126
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  def main():
129
- st.set_page_config(page_title="Local Web Development AI", layout="wide")
130
 
131
- st.title("🚀 Web Development AI Assistant")
132
- st.write("Generate web applications using local AI - no API key required!")
133
-
134
- if 'rag_system' not in st.session_state:
135
- with st.spinner("Initializing AI system... (this may take a few minutes on first run)"):
136
- st.session_state.rag_system = LocalWebDevRAG()
137
-
138
- with st.form("project_specs"):
139
  description = st.text_area(
140
  "Project Description",
141
  placeholder="Describe your web application..."
@@ -144,30 +177,32 @@ def main():
144
  col1, col2 = st.columns(2)
145
  with col1:
146
  frontend = st.selectbox(
147
- "Frontend",
148
- ["React", "Vue", "Angular"]
149
  )
150
- database = st.selectbox(
151
- "Database",
152
- ["MongoDB", "PostgreSQL", "MySQL"]
 
153
  )
154
 
155
  with col2:
156
- backend = st.selectbox(
157
- "Backend",
158
- ["Node.js", "Python/FastAPI", "Python/Django"]
159
  )
 
160
  features = st.multiselect(
161
  "Features",
162
- ["Authentication", "REST API", "File Upload", "Real-time Updates"]
163
  )
164
-
165
- submitted = st.form_submit_button("Generate Code")
166
 
167
- if submitted:
 
 
168
  try:
169
- with st.spinner("Generating your application..."):
170
- result = st.session_state.rag_system.generate_code(
171
  description,
172
  {
173
  "frontend": frontend,
@@ -177,17 +212,32 @@ def main():
177
  features
178
  )
179
 
180
- # Display results
181
- tabs = st.tabs(["Frontend", "Backend", "Database", "Setup"])
 
 
 
 
 
182
 
183
  with tabs[0]:
184
- st.code(result["frontend"])
 
185
  with tabs[1]:
186
- st.code(result["backend"])
 
187
  with tabs[2]:
188
- st.code(result["database"])
 
189
  with tabs[3]:
190
- st.markdown(result["setup"])
 
 
 
 
 
 
 
191
 
192
  except Exception as e:
193
  st.error(f"An error occurred: {str(e)}")
 
2
  from langchain_community.embeddings import HuggingFaceEmbeddings
3
  from langchain_community.vectorstores import Chroma
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
6
+ import torch
 
 
 
7
  import logging
8
+ from pathlib import Path
9
 
10
+ class LocalWebDevAssistant:
11
    def __init__(self):
        """Build the assistant: logging, local LLM pipeline, embeddings, vector store."""
        # Logger first so it exists for the rest of setup; the vector-store
        # step reads self.embeddings, so it runs after setup_embeddings.
        self.initialize_logging()
        self.setup_model()
        self.setup_embeddings()
        self.initialize_vector_store()
16
 
17
    def initialize_logging(self):
        """Configure root logging at INFO and keep a module-named logger on self."""
        # NOTE: basicConfig only takes effect on its first call per process.
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)
20
 
21
+ def setup_model(self):
22
+ # Using a smaller, directly available model
23
+ model_name = "facebook/opt-350m" # Smaller model that's good for code
24
+
25
+ @st.cache_resource
26
+ def load_model_and_tokenizer(model_name):
27
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
28
+ model = AutoModelForCausalLM.from_pretrained(
29
+ model_name,
30
+ torch_dtype=torch.float16,
31
+ low_cpu_mem_usage=True
32
+ )
33
+ return model, tokenizer
34
+
35
+ self.model, self.tokenizer = load_model_and_tokenizer(model_name)
36
+ self.generator = pipeline(
37
+ "text-generation",
38
+ model=self.model,
39
+ tokenizer=self.tokenizer,
40
+ max_length=1000
41
+ )
42
+
43
  def setup_embeddings(self):
44
  self.embeddings = HuggingFaceEmbeddings(
45
  model_name="all-MiniLM-L6-v2",
46
  model_kwargs={'device': 'cpu'}
47
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  def initialize_vector_store(self):
50
+ if not Path("chroma_db").exists():
51
+ self.create_knowledge_base()
52
+
53
+ self.vector_store = Chroma(
54
+ persist_directory="chroma_db",
55
+ embedding_function=self.embeddings
56
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
+ def create_knowledge_base(self):
59
+ # Basic web development knowledge
60
  documents = [
61
+ """React component structure:
62
+ import React from 'react';
63
+
64
+ const Component = ({ props }) => {
65
+ return (
66
+ <div>
67
+ {/* Component content */}
68
+ </div>
69
+ );
70
+ };
71
+
72
+ export default Component;
73
+ """,
74
+ """FastAPI backend structure:
75
+ from fastapi import FastAPI
76
+
77
+ app = FastAPI()
78
+
79
+ @app.get("/")
80
+ async def root():
81
+ return {"message": "Hello World"}
82
+ """,
83
+ """MongoDB connection:
84
+ from pymongo import MongoClient
85
+
86
+ client = MongoClient('mongodb://localhost:27017/')
87
+ db = client['database_name']
88
+ """
89
  ]
90
 
91
  text_splitter = RecursiveCharacterTextSplitter(
92
+ chunk_size=500,
93
+ chunk_overlap=50
94
  )
95
  texts = text_splitter.split_text('\n\n'.join(documents))
96
 
97
+ Chroma.from_texts(
98
  texts,
99
  self.embeddings,
100
  persist_directory="chroma_db"
101
  )
 
102
 
103
+ def generate_code(self, description, tech_stack, features):
104
  try:
105
+ # Create prompt
106
+ prompt = f"""Generate code for a web application with:
107
  Description: {description}
108
+ Technology Stack: {tech_stack}
109
+ Features: {features}
110
 
111
+ Provide the code in sections:
 
 
 
 
112
  """
113
 
114
+ # Get relevant context from vector store
115
+ docs = self.vector_store.similarity_search(description, k=2)
116
+ context = "\n".join(doc.page_content for doc in docs)
117
+
118
+ # Generate with context
119
+ full_prompt = f"{context}\n{prompt}"
120
+
121
+ response = self.generator(
122
+ full_prompt,
123
+ max_length=1000,
124
+ num_return_sequences=1
125
+ )[0]['generated_text']
126
+
127
  return self.process_response(response)
128
 
129
  except Exception as e:
 
131
  raise
132
 
133
  def process_response(self, response):
134
+ # Extract different code sections
135
+ sections = {
136
+ "frontend": "",
137
+ "backend": "",
138
+ "database": "",
139
+ "instructions": ""
140
  }
141
+
142
+ current_section = "frontend"
143
+ for line in response.split('\n'):
144
+ if "FRONTEND:" in line.upper():
145
+ current_section = "frontend"
146
+ continue
147
+ elif "BACKEND:" in line.upper():
148
+ current_section = "backend"
149
+ continue
150
+ elif "DATABASE:" in line.upper():
151
+ current_section = "database"
152
+ continue
153
+ elif "INSTRUCTIONS:" in line.upper():
154
+ current_section = "instructions"
155
+ continue
156
+
157
+ sections[current_section] += line + '\n'
158
+
159
+ return sections
160
 
161
  def main():
162
+ st.set_page_config(page_title="Web Development Assistant", layout="wide")
163
 
164
+ st.title("🚀 Web Development Assistant")
165
+ st.write("Generate web application code using AI")
166
+
167
+ if 'assistant' not in st.session_state:
168
+ with st.spinner("Initializing... (this may take a minute)"):
169
+ st.session_state.assistant = LocalWebDevAssistant()
170
+
171
+ with st.form("project_details"):
172
  description = st.text_area(
173
  "Project Description",
174
  placeholder="Describe your web application..."
 
177
  col1, col2 = st.columns(2)
178
  with col1:
179
  frontend = st.selectbox(
180
+ "Frontend Framework",
181
+ ["React", "Vue", "Plain JavaScript"]
182
  )
183
+
184
+ backend = st.selectbox(
185
+ "Backend Framework",
186
+ ["FastAPI", "Express", "Flask"]
187
  )
188
 
189
  with col2:
190
+ database = st.selectbox(
191
+ "Database",
192
+ ["MongoDB", "PostgreSQL", "SQLite"]
193
  )
194
+
195
  features = st.multiselect(
196
  "Features",
197
+ ["Authentication", "REST API", "Database CRUD", "Form Handling"]
198
  )
 
 
199
 
200
+ generate = st.form_submit_button("Generate Code")
201
+
202
+ if generate:
203
  try:
204
+ with st.spinner("Generating code..."):
205
+ result = st.session_state.assistant.generate_code(
206
  description,
207
  {
208
  "frontend": frontend,
 
212
  features
213
  )
214
 
215
+ # Display results in tabs
216
+ tabs = st.tabs([
217
+ "Frontend Code",
218
+ "Backend Code",
219
+ "Database Setup",
220
+ "Instructions"
221
+ ])
222
 
223
  with tabs[0]:
224
+ st.code(result["frontend"], language="javascript")
225
+
226
  with tabs[1]:
227
+ st.code(result["backend"], language="python")
228
+
229
  with tabs[2]:
230
+ st.code(result["database"], language="sql")
231
+
232
  with tabs[3]:
233
+ st.markdown(result["instructions"])
234
+
235
+ # Add download button
236
+ st.download_button(
237
+ "Download Code",
238
+ '\n\n'.join(result.values()),
239
+ file_name="generated_code.txt"
240
+ )
241
 
242
  except Exception as e:
243
  st.error(f"An error occurred: {str(e)}")