Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import difflib | |
| import os | |
| import re | |
| import hashlib | |
| from groq import Groq | |
# --- Page config ---
# NOTE(review): the leading glyph in the title looks like a mis-decoded emoji — confirm the file encoding.
st.set_page_config(
    page_title="π AI Assistant with Workflow + Semantic Search",
    layout="wide",
)

# --- Groq API Setup ---
# The API key is read from the environment. When it is missing or empty the
# app shows an error and calls st.stop() instead of building a client.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if GROQ_API_KEY:
    client = Groq(api_key=GROQ_API_KEY)
else:
    st.error("β Please set your GROQ_API_KEY environment variable.")
    st.stop()
# --- Cache for embeddings ---
# Maps the SHA-256 hex digest of a text to its pseudo-embedding vector.
embedding_cache = {}


def get_embedding(text):
    """Return a cached, character-based pseudo-embedding for *text*.

    Each of the first 512 characters contributes ``ord(c) % 100 / 100``, so
    the vector has ``min(len(text), 512)`` entries. Vectors are memoised in
    ``embedding_cache`` under the SHA-256 hex digest of the full text.
    """
    digest = hashlib.sha256(text.encode()).hexdigest()
    cached = embedding_cache.get(digest)
    if cached is not None:
        return cached

    vector = []
    for ch in text[:512]:
        vector.append(ord(ch) % 100 / 100)
    embedding_cache[digest] = vector
    return vector
def cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two numeric sequences.

    The dot product pairs elements with ``zip``, so it stops at the shorter
    sequence, while each norm is taken over the whole of its own sequence.
    A small constant (1e-8) is added to the denominator so that zero-length
    or all-zero vectors yield 0.0 instead of dividing by zero.
    """
    dot_product = sum(x * y for x, y in zip(vec1, vec2))
    length1 = sum(x * x for x in vec1) ** 0.5
    length2 = sum(y * y for y in vec2) ** 0.5
    denominator = length1 * length2 + 1e-8
    return dot_product / denominator
def split_code_into_chunks(code, lang):
    """Split source code into chunks, one per ``def``/``class`` header.

    Only Python is split; any other language is returned as a single chunk.

    A chunk starts at a line whose first token (after optional indentation)
    is ``def``, ``async def`` or ``class`` and runs up to the next such line.
    Anchoring on the start of the line, rather than on a full ``...):``
    signature, means headers with return annotations or signatures spread
    over several lines are still recognised. Any non-blank code that precedes
    the first header (imports, constants, script statements) is kept as its
    own leading chunk instead of being discarded.

    Args:
        code: The source text to split.
        lang: The language name; compared case-insensitively to "python".

    Returns:
        A list of chunk strings in source order. When no header is found
        (or the language is not Python) the list is ``[code]``.
    """
    if lang.lower() != "python":
        return [code]

    header = re.compile(
        r'^[ \t]*(?:async[ \t]+)?(?:def|class)[ \t]+\w+', re.MULTILINE
    )
    starts = [match.start() for match in header.finditer(code)]
    if not starts:
        return [code]

    chunks = []
    preamble = code[:starts[0]]
    if preamble.strip():
        # Keep module-level code so questions about it can still be matched.
        chunks.append(preamble)

    boundaries = starts + [len(code)]
    chunks.extend(
        code[begin:end] for begin, end in zip(boundaries, boundaries[1:])
    )
    return chunks
def groq_call(prompt):
    """Send *prompt* as a single user message and return the reply text.

    Uses the module-level ``client`` with the "llama3-70b-8192" model and
    returns the message content of the first choice in the response.
    """
    user_message = {"role": "user", "content": prompt}
    completion = client.chat.completions.create(
        messages=[user_message],
        model="llama3-70b-8192",
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
def semantic_search_improved(code, question, lang, skill, role, explain_lang):
    """Answer *question* about *code* using the most similar code chunks.

    The code is split into chunks, each chunk is scored against the question
    by cosine similarity of their embeddings, and the three highest-scoring
    chunks are placed into a prompt that is sent to the model.

    Args:
        code: Source code to search.
        question: The user's question.
        lang: Programming language name, used for chunking and in the prompt.
        skill: Skill level, inserted into the prompt.
        role: User role, inserted into the prompt.
        explain_lang: Natural language the answer should be written in.

    Returns:
        The model's reply as returned by ``groq_call``.
    """
    pieces = split_code_into_chunks(code, lang)
    question_vector = get_embedding(question)

    # Rank chunks from most to least similar; ties keep their source order.
    ranked = sorted(
        (
            (cosine_similarity(question_vector, get_embedding(piece)), piece)
            for piece in pieces
        ),
        key=lambda pair: pair[0],
        reverse=True,
    )
    best_pieces = [piece for _score, piece in ranked[:3]]
    combined_code = "\n\n".join(best_pieces)

    intro = f"You are a friendly and insightful {lang} expert helping a {skill} {role}.\n"
    context = f"Based on these relevant code snippets:\n{combined_code}\n"
    ask = f"Answer this question in {explain_lang}:\n{question}\n"
    closing = f"Explain which parts handle the question and how to modify them if needed."
    return groq_call(intro + context + ask + closing)
def error_detection_and_fixes(refactored_code, lang, skill, role, explain_lang):
    """Ask the model to find bugs, security flaws and performance issues.

    Args:
        refactored_code: The code to analyse; appended to the prompt.
        lang: Programming language name, inserted into the prompt.
        skill: Accepted for signature compatibility; not used in the prompt.
        role: Accepted for signature compatibility; not used in the prompt.
        explain_lang: Natural language the explanations should be written in.

    Returns:
        The model's reply as returned by ``groq_call``.
    """
    task = (
        f"You are a senior {lang} developer. Analyze this code for bugs, security flaws, "
        f"and performance issues. Suggest fixes with explanations in {explain_lang}:"
    )
    request = task + f"\n\n{refactored_code}"
    return groq_call(request)
# Fence labels commonly used for each language the workflow is asked about.
# The plain lowercase language name alone misses e.g. "cpp" and "csharp".
_FENCE_LABELS = {
    "python": ("python", "py", "python3"),
    "javascript": ("javascript", "js", "jsx", "node"),
    "typescript": ("typescript", "ts", "tsx"),
    "c++": ("c++", "cpp", "cxx", "cc"),
    "c#": ("c#", "csharp", "cs"),
    "java": ("java",),
}


def _extract_refactored_code(response, programming_language):
    """Return the code from the first suitable fenced block in *response*.

    Preference order: the first non-empty fenced block labelled with the
    requested language (or a known alias), then the first non-empty unlabeled
    fenced block, then the whole response unchanged.
    """
    language = programming_language.lower()
    accepted_labels = _FENCE_LABELS.get(language, (language,))
    unlabeled_body = ""

    # Splitting on the fence marker leaves prose at even indexes and fenced
    # content at odd indexes, so only the odd segments are inspected.
    for fenced in response.split("```")[1::2]:
        info_line, *rest = fenced.split("\n", 1)
        body = rest[0].rstrip() if rest else ""
        if not body.strip():
            # Empty or single-line fence: nothing usable as code.
            continue
        tokens = info_line.split()
        label = tokens[0].lower() if tokens else ""
        if label in accepted_labels:
            return body
        if not label and not unlabeled_body:
            unlabeled_body = body

    return unlabeled_body or response


def agentic_workflow(code, skill_level, programming_language, explanation_language, user_role):
    """Run the explain, refactor, review, error-detection and test-generation steps.

    Each step sends one prompt through ``groq_call`` and appends an entry to
    the timeline. The review, error-detection and test steps operate on the
    refactored code extracted from the refactor reply.

    Args:
        code: The user's source code.
        skill_level: Skill level, inserted into the prompts.
        programming_language: Language name, inserted into the prompts and
            used to locate the fenced code block in the refactor reply.
        explanation_language: Natural language for explanations.
        user_role: User role, inserted into the prompts.

    Returns:
        A ``(timeline, suggestions)`` tuple. ``timeline`` is a list of dicts
        with the keys "step", "description", "output" and "code";
        ``suggestions`` is a list of strings, one per step.
    """
    timeline = []
    suggestions = []

    # Explanation
    explain_prompt = (
        f"You are a friendly and insightful {programming_language} expert helping a {skill_level} {user_role}. "
        f"Explain this code in {explanation_language} with clear examples, analogies, and why each part matters:\n\n{code}"
    )
    explanation = groq_call(explain_prompt)
    timeline.append({"step": "Explain", "description": "Detailed explanation", "output": explanation, "code": code})
    suggestions.append("Consider refactoring your code to improve readability and performance.")

    # Refactor
    refactor_prompt = (
        f"Refactor this {programming_language} code. Explain the changes like a mentor helping a {skill_level} {user_role}. "
        f"Include best practices and improvements:\n\n{code}"
    )
    refactor_response = groq_call(refactor_prompt)
    # Falls back to the full reply when no usable fenced block is present.
    refactored_code = _extract_refactored_code(refactor_response, programming_language)
    timeline.append({"step": "Refactor", "description": "Refactored code with improvements", "output": refactored_code, "code": refactored_code})
    suggestions.append("Review the refactored code and adapt it to your style or project needs.")

    # Review
    review_prompt = (
        f"As a senior {programming_language} developer, review the refactored code. "
        f"Give constructive feedback on strengths, weaknesses, performance, security, and improvements in {explanation_language}:\n\n{refactored_code}"
    )
    review = groq_call(review_prompt)
    timeline.append({"step": "Review", "description": "Code review and suggestions", "output": review, "code": refactored_code})
    suggestions.append("Incorporate review feedback for cleaner, robust code.")

    # Error detection & fixes
    errors = error_detection_and_fixes(refactored_code, programming_language, skill_level, user_role, explanation_language)
    timeline.append({"step": "Error Detection", "description": "Bugs, security, performance suggestions", "output": errors, "code": refactored_code})
    suggestions.append("Apply fixes to improve code safety and performance.")

    # Test generation
    test_prompt = (
        f"Write clear, effective unit tests for this {programming_language} code. "
        f"Explain what each test does in {explanation_language}, for a {skill_level} {user_role}:\n\n{refactored_code}"
    )
    tests = groq_call(test_prompt)
    timeline.append({"step": "Test Generation", "description": "Generated unit tests", "output": tests, "code": tests})
    suggestions.append("Run generated tests locally to validate changes.")

    return timeline, suggestions
def get_inline_diff_html(original, modified):
    """Build an HTML side-by-side diff table of two text blobs.

    Both inputs are split into lines and compared with ``difflib.HtmlDiff``
    (tab size 4, wrapping at column 80). The table is in context mode with
    two context lines, uses the column headings "Original" and "Refactored",
    and is wrapped in a scrollable ``<div>``.
    """
    before_lines = original.splitlines()
    after_lines = modified.splitlines()
    renderer = difflib.HtmlDiff(tabsize=4, wrapcolumn=80)
    table_html = renderer.make_table(
        before_lines,
        after_lines,
        "Original",
        "Refactored",
        context=True,
        numlines=2,
    )
    return f'<div style="overflow-x:auto; max-height:400px;">{table_html}</div>'
def detect_code_type(code, programming_language):
    """Classify code as 'data_science', 'frontend', 'backend' or 'unknown'.

    Classification is by case-insensitive substring match against keyword
    lists, checked in this order:

    1. Any data-science keyword: 'data_science'.
    2. Any frontend keyword: 'frontend'.
    3. Python/Java/C# containing a backend keyword: 'backend'.
    4. Otherwise by language alone: Python/Java/C# is 'backend',
       JavaScript/TypeScript is 'frontend', anything else is 'unknown'.

    Args:
        code: Source text to inspect.
        programming_language: Language name; compared case-insensitively.

    Returns:
        One of 'data_science', 'frontend', 'backend', 'unknown'.
    """
    backend_keywords = [
        'flask', 'django', 'express', 'fastapi', 'spring', 'controller', 'api', 'server', 'database', 'sql', 'mongoose'
    ]
    frontend_keywords = [
        'react', 'vue', 'angular', 'component', 'html', 'css', 'document.getelementbyid', 'window.', 'render', 'jsx',
        '<html', '<body', '<script', '<div', 'getelementbyid', 'queryselector', 'addeventlistener', 'innerhtml'
    ]
    data_science_keywords = [
        'pandas', 'numpy', 'sklearn', 'matplotlib', 'seaborn', 'plt', 'train_test_split', 'randomforestclassifier', 'classification_report'
    ]
    code_lower = code.lower()
    language = programming_language.lower()
    backend_language = language in ['python', 'java', 'c#']

    if any(word in code_lower for word in data_science_keywords):
        return 'data_science'
    if any(word in code_lower for word in frontend_keywords):
        return 'frontend'
    # A language-specific frontend re-check used to follow the backend check;
    # it could never fire because the frontend check above already returned.
    if backend_language and any(word in code_lower for word in backend_keywords):
        return 'backend'
    if backend_language:
        return 'backend'
    if language in ['javascript', 'typescript']:
        return 'frontend'
    return 'unknown'
def code_complexity(code):
    """Return a one-line summary of simple substring counts for *code*.

    Lines is the number of newline characters plus one; Functions, Classes
    and Comments are the raw occurrence counts of 'def ', 'class ' and '#'.
    """
    metrics = {
        "Lines": code.count('\n') + 1,
        "Functions": code.count('def '),
        "Classes": code.count('class '),
        "Comments": code.count('#'),
    }
    return ", ".join(f"{label}: {value}" for label, value in metrics.items())
def code_matches_language(code: str, language: str) -> bool:
    """Check whether code contains at least one marker of the given language.

    The comparison is case-insensitive: the code is lowercased, and every
    pattern is lowercased before the substring test. Patterns written with
    capitals (``System.out.println``, ``@Override``, ``Console.WriteLine``,
    ``List<`` ...) would otherwise never match the lowercased code.

    Args:
        code: Source text to inspect.
        language: Language name, matched case-insensitively against the
            pattern table ("python", "c++", "java", "c#", "javascript",
            "typescript", "html").

    Returns:
        True when at least one pattern for the language occurs in the code;
        False otherwise, including for languages without a pattern list.
    """
    code_lower = code.strip().lower()
    language = language.lower()
    patterns = {
        "python": [
            "def ", "class ", "import ", "from ", "try:", "except", "raise", "lambda",
            "with ", "yield", "async ", "await", "print(", "self.", "__init__", "__name__",
            "if __name__ == '__main__':", "#!",  # shebang for executable scripts
        ],
        "c++": [
            "#include", "int main(", "std::", "::", "cout <<", "cin >>", "new ", "delete ",
            "try {", "catch(", "template<", "using namespace", "class ", "struct ", "#define",
        ],
        "java": [
            "package ", "import java.", "public class", "private ", "protected ", "public static void main",
            "System.out.println", "try {", "catch(", "throw new ", "implements ", "extends ",
            "@Override", "interface ", "enum ", "synchronized ", "final ",
        ],
        "c#": [
            "using System", "namespace ", "class ", "interface ", "public static void Main",
            "Console.WriteLine", "try {", "catch(", "throw ", "async ", "await ", "get;", "set;",
            "List<", "Dictionary<", "[Serializable]", "[Obsolete]",
        ],
        "javascript": [
            "function ", "const ", "let ", "var ", "document.", "window.", "console.log",
            "if(", "for(", "while(", "switch(", "try {", "catch(", "export ", "import ", "async ",
            "await ", "=>", "this.", "class ", "prototype", "new ", "$(",
        ],
        "typescript": [
            "function ", "const ", "let ", "interface ", "type ", ": string", ": number", ": boolean",
            "implements ", "extends ", "enum ", "public ", "private ", "protected ", "readonly ",
            "import ", "export ", "console.log", "async ", "await ", "=>", "this.",
        ],
        "html": [
            "<!doctype html", "<html", "<head>", "<body>", "<script", "<style", "<meta ", "<link ",
            "<title>", "<div", "<span", "<p>", "<h1>", "<ul>", "<li>", "<form", "<input", "<button",
            "<table", "<footer", "<header", "<section", "<article", "<nav", "<img", "<a ", "</html>",
        ],
    }
    match_patterns = patterns.get(language, [])
    # Require at least one pattern to match for validation to succeed.
    return any(pattern.lower() in code_lower for pattern in match_patterns)
# --- Sidebar ---
# All configuration widgets live in the sidebar; their selections are read
# by both tabs below.
with st.sidebar:
    st.title("π§ Configuration")
    lang = st.selectbox(
        "Programming Language",
        ["Python", "JavaScript", "C++", "Java", "C#", "TypeScript"],
    )
    skill = st.selectbox("Skill Level", ["Beginner", "Intermediate", "Expert"])
    role = st.selectbox(
        "Your Role",
        ["Student", "Frontend Developer", "Backend Developer", "Data Scientist"],
    )
    explain_lang = st.selectbox(
        "Explanation Language",
        ["English", "Spanish", "Chinese", "Urdu"],
    )
    st.markdown("---")
    st.markdown("<span style='color:#fff;'>Powered by <b>BLACKBOX.AI</b></span>", unsafe_allow_html=True)

tabs = st.tabs(["π§ Full AI Workflow", "π Semantic Search"])
# --- Tab 1: Full AI Workflow ---
# Takes code from an upload or a text area, validates it against the selected
# language and role, runs the workflow, and shows each step plus a
# downloadable report.
with tabs[0]:
    st.title("π§ Full AI Workflow")
    file_types = {
        "Python": ["py"],
        "JavaScript": ["js"],
        "C++": ["cpp", "h", "hpp"],
        "Java": ["java"],
        "C#": ["cs"],
        "TypeScript": ["ts"],
    }
    uploaded_file = st.file_uploader(
        f"Upload {', '.join(file_types.get(lang, []))} file(s)",
        type=file_types.get(lang, None)
    )
    if uploaded_file:
        try:
            code_input = uploaded_file.read().decode("utf-8")
        except UnicodeDecodeError:
            # A non-UTF-8 upload would otherwise surface as a raw traceback.
            st.error("The uploaded file is not valid UTF-8 text. Please upload a UTF-8 encoded source file.")
            code_input = ""
    else:
        code_input = st.text_area("Your Code", height=300, placeholder="Paste your code here...")
    if code_input:
        st.markdown(f"<b>Complexity:</b> {code_complexity(code_input)}", unsafe_allow_html=True)
    if st.button("Run AI Workflow"):
        if not code_input.strip():
            st.warning("Please paste or upload your code.")
        elif not code_matches_language(code_input, lang):
            st.error(f"The pasted code doesnβt look like valid {lang} code. Please check your code or select the correct language.")
        else:
            code_type = detect_code_type(code_input, lang)
            # detect_code_type falls back to 'backend'/'frontend' for most
            # languages, so the generic "Student" role would be rejected for
            # nearly all input; students are therefore exempt from this check.
            enforce_role = role != "Student"
            if enforce_role and code_type == "data_science" and role != "Data Scientist":
                st.error("Data science code detected. Please select 'Data Scientist' role.")
            elif enforce_role and code_type == "frontend" and role != "Frontend Developer":
                st.error("Frontend code detected. Please select 'Frontend Developer' role.")
            elif enforce_role and code_type == "backend" and role != "Backend Developer":
                st.error("Backend code detected. Please select 'Backend Developer' role.")
            else:
                with st.spinner("Running agentic workflow..."):
                    timeline, suggestions = agentic_workflow(code_input, skill, lang, explain_lang, role)
                # Show each step in an expander
                for step in timeline:
                    with st.expander(f"β {step['step']} - {step['description']}"):
                        if step['step'] == "Refactor":
                            # The refactor step shows a diff against the input plus the new code.
                            diff_html = get_inline_diff_html(code_input, step['code'])
                            st.markdown(diff_html, unsafe_allow_html=True)
                            st.code(step['output'], language=lang.lower())
                        else:
                            st.markdown(step['output'])
                st.markdown("#### Agent Suggestions")
                for s in suggestions:
                    st.markdown(f"- {s}")
                # Download buttons after suggestions
                st.markdown("---")
                st.markdown("### π₯ Download Results")
                report_text = "".join(
                    f"## {step['step']}\n{step['description']}\n\n{step['output']}\n\n"
                    for step in timeline
                )
                st.download_button(
                    label="π Download Full Workflow Report",
                    data=report_text,
                    file_name="ai_workflow_report.txt",
                    mime="text/plain",
                )
# --- Tab 2: Semantic Search ---
# Asks the model a question about pasted code; both fields must be non-blank.
with tabs[1]:
    st.title("π Semantic Search")
    sem_code = st.text_area("Your Code", height=300, placeholder="Paste your code...")
    sem_q = st.text_input("Your Question", placeholder="E.g., What does this function do?")
    if st.button("Run Semantic Search"):
        have_inputs = bool(sem_code.strip()) and bool(sem_q.strip())
        if have_inputs:
            with st.spinner("Running semantic search..."):
                answer = semantic_search_improved(sem_code, sem_q, lang, skill, role, explain_lang)
            st.markdown("### π Answer")
            st.markdown(answer)
        else:
            st.warning("Code and question required.")
    st.markdown("---")