Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,246 +1,242 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
-
from
|
| 3 |
-
|
| 4 |
-
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 5 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
| 6 |
-
import torch
|
| 7 |
-
import logging
|
| 8 |
from pathlib import Path
|
|
|
|
| 9 |
|
| 10 |
-
class
|
| 11 |
def __init__(self):
|
| 12 |
-
|
| 13 |
-
self.setup_model()
|
| 14 |
-
self.setup_embeddings()
|
| 15 |
-
self.initialize_vector_store()
|
| 16 |
-
|
| 17 |
-
def initialize_logging(self):
|
| 18 |
-
logging.basicConfig(level=logging.INFO)
|
| 19 |
-
self.logger = logging.getLogger(__name__)
|
| 20 |
-
|
| 21 |
-
def setup_model(self):
|
| 22 |
-
# Using a smaller, directly available model
|
| 23 |
-
model_name = "facebook/opt-350m" # Smaller model that's good for code
|
| 24 |
-
|
| 25 |
-
@st.cache_resource
|
| 26 |
-
def load_model_and_tokenizer(model_name):
|
| 27 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 28 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 29 |
-
model_name,
|
| 30 |
-
torch_dtype=torch.float16,
|
| 31 |
-
low_cpu_mem_usage=True
|
| 32 |
-
)
|
| 33 |
-
return model, tokenizer
|
| 34 |
-
|
| 35 |
-
self.model, self.tokenizer = load_model_and_tokenizer(model_name)
|
| 36 |
self.generator = pipeline(
|
| 37 |
-
|
| 38 |
-
model=
|
| 39 |
-
|
| 40 |
-
max_length=1000
|
| 41 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
model_name="all-MiniLM-L6-v2",
|
| 46 |
-
model_kwargs={'device': 'cpu'}
|
| 47 |
-
)
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
self.vector_store = Chroma(
|
| 54 |
-
persist_directory="chroma_db",
|
| 55 |
-
embedding_function=self.embeddings
|
| 56 |
-
)
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
{/* Component content */}
|
| 68 |
-
</div>
|
| 69 |
-
);
|
| 70 |
-
};
|
| 71 |
-
|
| 72 |
-
export default Component;
|
| 73 |
-
""",
|
| 74 |
-
"""FastAPI backend structure:
|
| 75 |
-
from fastapi import FastAPI
|
| 76 |
-
|
| 77 |
-
app = FastAPI()
|
| 78 |
-
|
| 79 |
-
@app.get("/")
|
| 80 |
-
async def root():
|
| 81 |
-
return {"message": "Hello World"}
|
| 82 |
-
""",
|
| 83 |
-
"""MongoDB connection:
|
| 84 |
-
from pymongo import MongoClient
|
| 85 |
-
|
| 86 |
-
client = MongoClient('mongodb://localhost:27017/')
|
| 87 |
-
db = client['database_name']
|
| 88 |
-
"""
|
| 89 |
-
]
|
| 90 |
-
|
| 91 |
-
text_splitter = RecursiveCharacterTextSplitter(
|
| 92 |
-
chunk_size=500,
|
| 93 |
-
chunk_overlap=50
|
| 94 |
-
)
|
| 95 |
-
texts = text_splitter.split_text('\n\n'.join(documents))
|
| 96 |
-
|
| 97 |
-
Chroma.from_texts(
|
| 98 |
-
texts,
|
| 99 |
-
self.embeddings,
|
| 100 |
-
persist_directory="chroma_db"
|
| 101 |
-
)
|
| 102 |
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
current_section = "frontend"
|
| 146 |
-
continue
|
| 147 |
-
elif "BACKEND:" in line.upper():
|
| 148 |
-
current_section = "backend"
|
| 149 |
-
continue
|
| 150 |
-
elif "DATABASE:" in line.upper():
|
| 151 |
-
current_section = "database"
|
| 152 |
-
continue
|
| 153 |
-
elif "INSTRUCTIONS:" in line.upper():
|
| 154 |
-
current_section = "instructions"
|
| 155 |
-
continue
|
| 156 |
-
|
| 157 |
-
sections[current_section] += line + '\n'
|
| 158 |
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
def main():
|
| 162 |
-
st.set_page_config(page_title="
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
|
|
|
| 166 |
|
| 167 |
-
|
| 168 |
-
with st.spinner("Initializing... (this may take a minute)"):
|
| 169 |
-
st.session_state.assistant = LocalWebDevAssistant()
|
| 170 |
|
| 171 |
-
with st.form("
|
| 172 |
-
description = st.text_area(
|
| 173 |
-
"Project Description",
|
| 174 |
-
placeholder="Describe your web application..."
|
| 175 |
-
)
|
| 176 |
-
|
| 177 |
col1, col2 = st.columns(2)
|
|
|
|
| 178 |
with col1:
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
)
|
| 183 |
-
|
| 184 |
-
backend = st.selectbox(
|
| 185 |
-
"Backend Framework",
|
| 186 |
-
["FastAPI", "Express", "Flask"]
|
| 187 |
-
)
|
| 188 |
|
| 189 |
with col2:
|
| 190 |
-
database = st.selectbox(
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
)
|
| 199 |
-
|
| 200 |
-
generate = st.form_submit_button("Generate Code")
|
| 201 |
-
|
| 202 |
-
if generate:
|
| 203 |
-
try:
|
| 204 |
-
with st.spinner("Generating code..."):
|
| 205 |
-
result = st.session_state.assistant.generate_code(
|
| 206 |
-
description,
|
| 207 |
-
{
|
| 208 |
-
"frontend": frontend,
|
| 209 |
-
"backend": backend,
|
| 210 |
-
"database": database
|
| 211 |
-
},
|
| 212 |
-
features
|
| 213 |
-
)
|
| 214 |
-
|
| 215 |
-
# Display results in tabs
|
| 216 |
-
tabs = st.tabs([
|
| 217 |
-
"Frontend Code",
|
| 218 |
-
"Backend Code",
|
| 219 |
-
"Database Setup",
|
| 220 |
-
"Instructions"
|
| 221 |
-
])
|
| 222 |
-
|
| 223 |
-
with tabs[0]:
|
| 224 |
-
st.code(result["frontend"], language="javascript")
|
| 225 |
-
|
| 226 |
-
with tabs[1]:
|
| 227 |
-
st.code(result["backend"], language="python")
|
| 228 |
-
|
| 229 |
-
with tabs[2]:
|
| 230 |
-
st.code(result["database"], language="sql")
|
| 231 |
-
|
| 232 |
-
with tabs[3]:
|
| 233 |
-
st.markdown(result["instructions"])
|
| 234 |
-
|
| 235 |
-
# Add download button
|
| 236 |
-
st.download_button(
|
| 237 |
-
"Download Code",
|
| 238 |
-
'\n\n'.join(result.values()),
|
| 239 |
-
file_name="generated_code.txt"
|
| 240 |
-
)
|
| 241 |
-
|
| 242 |
-
except Exception as e:
|
| 243 |
-
st.error(f"An error occurred: {str(e)}")
|
| 244 |
|
| 245 |
if __name__ == "__main__":
|
| 246 |
main()
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
from transformers import pipeline
|
| 3 |
+
import json
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from pathlib import Path
|
| 5 |
+
import os
|
| 6 |
|
| 7 |
+
class FastWebGenerator:
|
| 8 |
def __init__(self):
|
| 9 |
+
# Load the smallest model possible for quick text completion
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
self.generator = pipeline(
|
| 11 |
+
'text-generation',
|
| 12 |
+
model='distilgpt2',
|
| 13 |
+
device_map='auto'
|
|
|
|
| 14 |
)
|
| 15 |
+
|
| 16 |
+
# Load templates
|
| 17 |
+
self.templates = {
|
| 18 |
+
"react": {
|
| 19 |
+
"component": """
|
| 20 |
+
import React, { useState, useEffect } from 'react';
|
| 21 |
|
| 22 |
+
const {component_name} = () => {
|
| 23 |
+
const [data, setData] = useState([]);
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
+
useEffect(() => {
|
| 26 |
+
// Fetch data
|
| 27 |
+
fetchData();
|
| 28 |
+
}, []);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
+
const fetchData = async () => {
|
| 31 |
+
try {
|
| 32 |
+
const response = await fetch('/api/data');
|
| 33 |
+
const jsonData = await response.json();
|
| 34 |
+
setData(jsonData);
|
| 35 |
+
} catch (error) {
|
| 36 |
+
console.error('Error:', error);
|
| 37 |
+
}
|
| 38 |
+
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
+
return (
|
| 41 |
+
<div className="container mx-auto p-4">
|
| 42 |
+
<h1 className="text-2xl font-bold mb-4">{title}</h1>
|
| 43 |
+
{content}
|
| 44 |
+
</div>
|
| 45 |
+
);
|
| 46 |
+
};
|
| 47 |
+
|
| 48 |
+
export default {component_name};
|
| 49 |
+
""",
|
| 50 |
+
"api": """
|
| 51 |
+
async function fetchData() {
|
| 52 |
+
const response = await fetch('/api/endpoint');
|
| 53 |
+
const data = await response.json();
|
| 54 |
+
return data;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
async function postData(data) {
|
| 58 |
+
const response = await fetch('/api/endpoint', {
|
| 59 |
+
method: 'POST',
|
| 60 |
+
headers: {
|
| 61 |
+
'Content-Type': 'application/json',
|
| 62 |
+
},
|
| 63 |
+
body: JSON.stringify(data),
|
| 64 |
+
});
|
| 65 |
+
return response.json();
|
| 66 |
+
}
|
| 67 |
+
"""
|
| 68 |
+
},
|
| 69 |
+
"fastapi": {
|
| 70 |
+
"main": """
|
| 71 |
+
from fastapi import FastAPI, HTTPException
|
| 72 |
+
from pydantic import BaseModel
|
| 73 |
+
from typing import List, Optional
|
| 74 |
+
import uvicorn
|
| 75 |
+
|
| 76 |
+
app = FastAPI()
|
| 77 |
+
|
| 78 |
+
class ItemModel(BaseModel):
|
| 79 |
+
id: Optional[int]
|
| 80 |
+
name: str
|
| 81 |
+
description: Optional[str]
|
| 82 |
+
|
| 83 |
+
@app.get("/")
|
| 84 |
+
async def root():
|
| 85 |
+
return {"message": "API is running"}
|
| 86 |
+
|
| 87 |
+
@app.get("/items")
|
| 88 |
+
async def get_items():
|
| 89 |
+
return {"items": items}
|
| 90 |
+
|
| 91 |
+
@app.post("/items")
|
| 92 |
+
async def create_item(item: ItemModel):
|
| 93 |
+
items.append(item)
|
| 94 |
+
return item
|
| 95 |
+
|
| 96 |
+
if __name__ == "__main__":
|
| 97 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 98 |
+
""",
|
| 99 |
+
"crud": """
|
| 100 |
+
@app.get("/items/{item_id}")
|
| 101 |
+
async def get_item(item_id: int):
|
| 102 |
+
if item_id < 0 or item_id >= len(items):
|
| 103 |
+
raise HTTPException(status_code=404, detail="Item not found")
|
| 104 |
+
return items[item_id]
|
| 105 |
+
|
| 106 |
+
@app.put("/items/{item_id}")
|
| 107 |
+
async def update_item(item_id: int, item: ItemModel):
|
| 108 |
+
if item_id < 0 or item_id >= len(items):
|
| 109 |
+
raise HTTPException(status_code=404, detail="Item not found")
|
| 110 |
+
items[item_id] = item
|
| 111 |
+
return item
|
| 112 |
+
|
| 113 |
+
@app.delete("/items/{item_id}")
|
| 114 |
+
async def delete_item(item_id: int):
|
| 115 |
+
if item_id < 0 or item_id >= len(items):
|
| 116 |
+
raise HTTPException(status_code=404, detail="Item not found")
|
| 117 |
+
item = items.pop(item_id)
|
| 118 |
+
return item
|
| 119 |
+
"""
|
| 120 |
+
},
|
| 121 |
+
"mongodb": """
|
| 122 |
+
from motor.motor_asyncio import AsyncIOMotorClient
|
| 123 |
+
|
| 124 |
+
client = AsyncIOMotorClient('mongodb://localhost:27017')
|
| 125 |
+
db = client.database_name
|
| 126 |
+
|
| 127 |
+
async def get_all():
|
| 128 |
+
cursor = db.collection.find({})
|
| 129 |
+
return await cursor.to_list(length=100)
|
| 130 |
+
|
| 131 |
+
async def get_one(id: str):
|
| 132 |
+
return await db.collection.find_one({"_id": id})
|
| 133 |
+
|
| 134 |
+
async def create_one(data: dict):
|
| 135 |
+
result = await db.collection.insert_one(data)
|
| 136 |
+
return result.inserted_id
|
| 137 |
+
|
| 138 |
+
async def update_one(id: str, data: dict):
|
| 139 |
+
result = await db.collection.update_one(
|
| 140 |
+
{"_id": id},
|
| 141 |
+
{"$set": data}
|
| 142 |
+
)
|
| 143 |
+
return result.modified_count
|
| 144 |
+
|
| 145 |
+
async def delete_one(id: str):
|
| 146 |
+
result = await db.collection.delete_one({"_id": id})
|
| 147 |
+
return result.deleted_count
|
| 148 |
+
"""
|
| 149 |
}
|
| 150 |
+
|
| 151 |
+
def customize_code(self, template, params):
|
| 152 |
+
# Quick text customization using the model
|
| 153 |
+
prompt = f"Customize this code: {template}\nWith these parameters: {params}\n"
|
| 154 |
+
result = self.generator(prompt, max_length=100, num_return_sequences=1)[0]['generated_text']
|
| 155 |
+
return result
|
| 156 |
+
|
| 157 |
+
def generate_project(self, specs):
|
| 158 |
+
frontend_code = self.templates["react"]["component"].replace(
|
| 159 |
+
"{component_name}", specs["component_name"]
|
| 160 |
+
).replace(
|
| 161 |
+
"{title}", specs["title"]
|
| 162 |
+
).replace(
|
| 163 |
+
"{content}", specs["content"]
|
| 164 |
+
)
|
| 165 |
|
| 166 |
+
backend_code = self.templates["fastapi"]["main"]
|
| 167 |
+
if specs.get("crud", False):
|
| 168 |
+
backend_code += self.templates["fastapi"]["crud"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
+
db_code = self.templates["mongodb"] if specs["database"] == "MongoDB" else ""
|
| 171 |
+
|
| 172 |
+
return {
|
| 173 |
+
"frontend": frontend_code,
|
| 174 |
+
"backend": backend_code,
|
| 175 |
+
"database": db_code
|
| 176 |
+
}
|
| 177 |
|
| 178 |
def main():
|
| 179 |
+
st.set_page_config(page_title="Fast Code Generator", layout="wide")
|
| 180 |
|
| 181 |
+
if 'generator' not in st.session_state:
|
| 182 |
+
with st.spinner("Loading generator..."):
|
| 183 |
+
st.session_state.generator = FastWebGenerator()
|
| 184 |
|
| 185 |
+
st.title("⚡ Fast Code Generator")
|
|
|
|
|
|
|
| 186 |
|
| 187 |
+
with st.form("generate_form"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
col1, col2 = st.columns(2)
|
| 189 |
+
|
| 190 |
with col1:
|
| 191 |
+
component_name = st.text_input("Component Name", "MyComponent")
|
| 192 |
+
title = st.text_input("Title", "My App")
|
| 193 |
+
content = st.text_area("Content", "<p>Your content here</p>")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
with col2:
|
| 196 |
+
database = st.selectbox("Database", ["MongoDB", "PostgreSQL"])
|
| 197 |
+
add_crud = st.checkbox("Add CRUD Operations", value=True)
|
| 198 |
+
add_auth = st.checkbox("Add Authentication", value=False)
|
| 199 |
+
|
| 200 |
+
if st.form_submit_button("Generate"):
|
| 201 |
+
specs = {
|
| 202 |
+
"component_name": component_name,
|
| 203 |
+
"title": title,
|
| 204 |
+
"content": content,
|
| 205 |
+
"database": database,
|
| 206 |
+
"crud": add_crud,
|
| 207 |
+
"auth": add_auth
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
result = st.session_state.generator.generate_project(specs)
|
| 211 |
|
| 212 |
+
# Display results
|
| 213 |
+
tabs = st.tabs(["Frontend", "Backend", "Database"])
|
| 214 |
+
|
| 215 |
+
with tabs[0]:
|
| 216 |
+
st.code(result["frontend"], language="javascript")
|
| 217 |
+
with tabs[1]:
|
| 218 |
+
st.code(result["backend"], language="python")
|
| 219 |
+
with tabs[2]:
|
| 220 |
+
st.code(result["database"], language="python")
|
| 221 |
+
|
| 222 |
+
# Add download button
|
| 223 |
+
combined_code = f"""
|
| 224 |
+
// Frontend Code
|
| 225 |
+
{result['frontend']}
|
| 226 |
+
|
| 227 |
+
// Backend Code
|
| 228 |
+
{result['backend']}
|
| 229 |
+
|
| 230 |
+
// Database Code
|
| 231 |
+
{result['database']}
|
| 232 |
+
"""
|
| 233 |
+
|
| 234 |
+
st.download_button(
|
| 235 |
+
"Download Code",
|
| 236 |
+
combined_code,
|
| 237 |
+
file_name="generated_code.txt",
|
| 238 |
+
mime="text/plain"
|
| 239 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
if __name__ == "__main__":
|
| 242 |
main()
|