Spaces:
Sleeping
Sleeping
| from langchain_openai import ChatOpenAI | |
| from langchain_ollama import ChatOllama | |
| from langchain_groq import ChatGroq | |
| from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.document_loaders import WebBaseLoader | |
| # from langchain_community.vectorstores import Chroma | |
| from langchain_chroma import Chroma | |
| from langchain_community.embeddings import HuggingFaceBgeEmbeddings | |
| import pickle | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_openai import ChatOpenAI | |
| from pydantic import BaseModel, Field | |
| from typing import List | |
| from typing_extensions import TypedDict | |
| from langgraph.graph import END, StateGraph, START | |
| import subprocess | |
| import time | |
| import re | |
| import json | |
| import os | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
# Set TOKENIZERS_PARALLELISM before the embedding model is loaded.
# NOTE(review): presumably this silences the tokenizers parallelism warning — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# llm = ChatOllama(model="codestral")
# Chat model used for code generation (deterministic: temperature=0).
expt_llm = "gpt-4o-mini"
llm = ChatOpenAI(temperature=0, model=expt_llm)
## Create retrieval from existing store
# Load a previously saved embedding model from a pickle file; the path is
# relative to the current working directory.
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# this file must only ever come from a trusted source.
# model_path = "/Model/embedding_model.pkl"
model_path = "embedding_model.pkl"
with open(model_path, 'rb') as f:
    embedding_model = pickle.load(f)
print("Loaded embedding model successfully")
# Open the persisted Chroma collection with that embedding model and expose it
# as the retriever used by the graph nodes.
vectorstore = Chroma(
    collection_name="solcoder-chroma",
    embedding_function=embedding_model,
    persist_directory="solcoder-db-cpu"
)
retriever = vectorstore.as_retriever()
# Code-generation prompt: a system message carrying the instructions and the
# retrieved documentation ({context}), followed by the running conversation,
# which is spliced in through the {messages} placeholder.
code_gen_prompt = ChatPromptTemplate(
    [
        (
            "system",
            """<instructions> You are a coding assistant with expertise in Solana Blockchain ecosystem. \n
Here is a set of Solana development documentation based on a user question: \n ------- \n {context} \n ------- \n
Answer the user question based on the above provided documentation. Ensure any code you provide can be executed with all required imports and variables \n
defined. Structure your answer: 1) a prefix describing the code solution, 2) the imports, 3) the functioning code block. \n
Invoke the code tool to structure the output correctly. </instructions> \n Here is the user question:""",
        ),
        ("placeholder", "{messages}"),
    ]
)
| # Data model | |
class code(BaseModel):
    """Schema for code solutions to questions about Solana development."""

    # Free-text explanation that precedes the code.
    prefix: str = Field(description="Description of the problem and approach")
    # Import statements only; kept separate from the body.
    imports: str = Field(description="Code block import statements")
    # The solution body, without the imports above.
    code: str = Field(description="Code block not including import statements")
    # Language name; code_check lower-cases it to choose how to test the code.
    language: str = Field(description="programming language the code is implemented")

    class Config:
        # Worked example attached to the generated JSON schema.
        json_schema_extra = {
            "example": {
                "prefix": "To read the balance of an account from the Solana network, you can use the `@solana/web3.js` library.",
                "imports": 'import { clusterApiUrl, Connection, PublicKey, LAMPORTS_PER_SOL,} from "@solana/web3.js";',
                "code": (
                    'const connection = new Connection(clusterApiUrl("devnet"), "confirmed");\n'
                    'const wallet = new PublicKey("nicktrLHhYzLmoVbuZQzHUTicd2sfP571orwo9jfc8c");\n'
                    "const balance = await connection.getBalance(wallet);\n"
                    "console.log(`Balance: ${balance / LAMPORTS_PER_SOL} SOL`);"
                ),
                "language": "typescript",
            }
        }
| # expt_llm = "codestral" | |
| # llm = ChatOllama(temperature=0, model=expt_llm) | |
| # Post-processing | |
def format_docs(docs):
    """Concatenate the ``page_content`` of every document, separated by a blank line.

    Args:
        docs: Iterable of objects exposing a ``page_content`` string.

    Returns:
        str: The joined text; an empty string when ``docs`` is empty.
    """
    contents = [document.page_content for document in docs]
    separator = "\n\n"
    return separator.join(contents)
# Prompt piped into the LLM bound to the `code` schema. With include_raw=True
# the output is a dict with 'raw', 'parsed' and 'parsing_error' keys, which
# check_llm_output and parse_output read.
structured_llm = code_gen_prompt | llm.with_structured_output(code, include_raw=True)
| # Optional: Check for errors in case tool use is flaky | |
def check_llm_output(tool_output):
    """Validate a structured-output result and pass it through unchanged.

    Args:
        tool_output (dict): Result with the keys 'raw', 'parsed' and
            'parsing_error'.

    Returns:
        dict: ``tool_output`` itself when it parsed and the tool was invoked.

    Raises:
        ValueError: If a parsing error was reported, or if nothing was parsed
            (the model did not invoke the tool).
    """
    parse_failure = tool_output["parsing_error"]

    # Case 1: the tool call was made but its arguments could not be parsed.
    if parse_failure:
        print("Parsing error!")
        raw_output = str(tool_output["raw"].content)
        raise ValueError(
            f"Error parsing your output! Be sure to invoke the tool. Output: {raw_output}. \n Parse error: {parse_failure}"
        )

    # Case 2: no parse error, but no parsed payload either.
    if not tool_output["parsed"]:
        print("Failed to invoke tool!")
        raise ValueError(
            "You did not use the provided tool! Be sure to invoke the tool to structure the output."
        )

    return tool_output
# Chain with output check.
# `structured_llm` already starts with `code_gen_prompt`, so the prompt is not
# piped in again here; doing so would run the template a second time on its own
# rendered output instead of on the {context}/{messages} input.
code_chain_raw = structured_llm | check_llm_output
def insert_errors(inputs):
    """Build the fallback-chain input by appending a retry instruction.

    The caller's ``messages`` list is left untouched: a new list is returned,
    so the retry text does not leak into the caller's conversation or pile up
    across successive fallback attempts.

    Args:
        inputs (dict): Must contain 'error' (the failure to report back),
            'messages' (list of (role, content) tuples) and 'context'.

    Returns:
        dict: ``{"messages": <copy + retry message>, "context": <unchanged>}``.
    """
    error = inputs["error"]
    retry_message = (
        "assistant",
        f"Retry. You are required to fix the parsing errors: {error} \n\n You must invoke the provided tool.",
    )
    return {
        "messages": [*inputs["messages"], retry_message],
        "context": inputs["context"],
    }
# Fallback chain: insert_errors adds the previous failure to the messages, then
# the checked chain is run again.
fallback_chain = insert_errors | code_chain_raw
N = 3  # Max re-tries
# Primary chain with up to N fallback attempts. exception_key="error" is the
# key insert_errors reads the failure from.
code_gen_chain_re_try = code_chain_raw.with_fallbacks(
    fallbacks=[fallback_chain] * N, exception_key="error"
)
def parse_output(solution):
    """Extract the parsed object from a structured-output result.

    With ``include_raw=True`` the structured output is a dict holding 'raw',
    'parsed' and 'parsing_error'; only the 'parsed' entry is returned.

    Args:
        solution (dict): Structured-output result.

    Returns:
        The value stored under 'parsed'.
    """
    parsed = solution["parsed"]
    return parsed
# Final generation chain: the retry-wrapped chain followed by parse_output, so
# callers receive only the parsed solution object.
code_gen_chain = code_gen_chain_re_try | parse_output
# No re-try
# code_gen_chain = code_gen_prompt | structured_llm | parse_output
| ### Create State | |
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        error : "yes"/"no" flag for control flow, set by code_check to say
            whether a test error was tripped (callers may start with "")
        messages : (role, content) tuples: user question, error messages, reasoning
        generation : Code solution; generate stores the object returned by
            code_gen_chain here
        iterations : Number of tries; incremented by generate
    """
    error: str
    messages: List
    generation: List
    iterations: int
| ### HELPER FUNCTIONS | |
def check_node_typescript_installation():
    """Check if Node.js and TypeScript are properly installed.

    Runs ``node --version`` and then ``npx tsc --version``. A tool counts as
    missing when its executable cannot be found or when the command exits with
    a non-zero status.

    Returns:
        tuple[bool, str]: ``(True, "Environment OK")`` when both commands
        succeed, otherwise ``(False, <reason>)``. The function does not raise
        for launch failures; they are reported through the message.
    """
    checks = (
        (["node", "--version"], "Node.js is not installed or not in PATH"),
        (
            ["npx", "tsc", "--version"],
            "TypeScript is not installed. Please run 'npm install -g typescript'",
        ),
    )
    for command, failure_message in checks:
        try:
            completed = subprocess.run(command, capture_output=True, text=True)
        except FileNotFoundError:
            # A missing executable raises instead of giving a return code, so
            # map it onto the same "not installed" message.
            return False, failure_message
        except (OSError, subprocess.SubprocessError) as e:
            return False, f"Error checking environment: {str(e)}"
        if completed.returncode != 0:
            return False, failure_message
    return True, "Environment OK"
def create_temp_package_json():
    """Write a minimal ``package.json`` into the current working directory.

    The manifest declares an ES-module package ("type": "module") with a
    TypeScript dependency. Any existing ``package.json`` in the working
    directory is overwritten.
    """
    manifest = {
        "name": "temp-code-execution",
        "version": "1.0.0",
        "type": "module",
    }
    manifest["dependencies"] = {"typescript": "^4.9.5"}

    with open("package.json", "w") as manifest_file:
        json.dump(manifest, manifest_file)
def run_javascript_code(code, is_typescript=False):
    """Execute JavaScript or TypeScript code using Node.js.

    The source is written to ``temp_code.js`` (or ``temp_code.ts``, which is
    first compiled with ``npx tsc``) in the current working directory next to
    a temporary ``package.json``. All three files are removed afterwards,
    whether or not the run succeeded.

    NOTE(review): this runs the given source on the host without sandboxing;
    callers pass model-generated code, so confirm that is acceptable.

    Args:
        code (str): Full source text to run.
        is_typescript (bool): Compile with ``tsc`` before running.

    Returns:
        subprocess.CompletedProcess | str: The finished ``node`` run, or the
        finished ``tsc`` run when compilation exits non-zero. A string
        starting with "Environment Error:" or "Error:" is returned when the
        toolchain is unavailable or a process/file operation fails.
    """
    # Check environment first
    env_ok, env_message = check_node_typescript_installation()
    if not env_ok:
        return f"Environment Error: {env_message}"

    cleanup_files = ["temp_code.js", "temp_code.ts", "package.json"]
    try:
        create_temp_package_json()
        if is_typescript:
            with open("temp_code.ts", "w") as f:
                f.write(code)
            compile_process = subprocess.run(
                ["npx", "tsc", "temp_code.ts", "--module", "ES2020", "--target", "ES2020"],
                capture_output=True,
                text=True
            )
            # Stop only when compilation failed; previously this returned
            # unconditionally, so compiled code was never executed.
            if compile_process.returncode != 0:
                return compile_process
        else:
            with open("temp_code.js", "w") as f:
                f.write(code)

        # Both paths end up with temp_code.js as the file to execute.
        return subprocess.run(
            ["node", "temp_code.js"],
            capture_output=True,
            text=True
        )
    except (OSError, subprocess.SubprocessError) as e:
        return f"Error: {e}"
    finally:
        # Clean up on every exit path, including the early compile-failure return.
        for file in cleanup_files:
            if os.path.exists(file):
                os.remove(file)
def run_rust_code(code):
    """Compile Rust source with ``rustc`` and run the resulting binary.

    The source is written to ``code.rs`` in the current working directory and
    the binary is executed as ``./code``; neither file is removed.

    Args:
        code (str): Rust source text.

    Returns:
        str: "Compilation Error: <stderr>" when ``rustc`` exits non-zero;
        otherwise the program's stderr if it produced any, else its stdout.
    """
    with open('code.rs', 'w') as source_file:
        source_file.write(code)

    compiled = subprocess.run(
        ['rustc', 'code.rs'],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    if compiled.returncode != 0:
        return f"Compilation Error: {compiled.stderr}"

    executed = subprocess.run(
        ['./code'],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    if executed.stderr:
        return executed.stderr
    return executed.stdout
### Parameter
# Max tries: decide_to_finish compares the state's `iterations` against this value.
max_iterations = 3
# Retry routing read by decide_to_finish: "reflect" routes a failed check to
# the reflect node; any other value routes straight back to generate.
# flag = 'reflect'
flag = "do not reflect"
| ### Nodes | |
def generate(state: GraphState):
    """
    Generate a code solution

    Retrieves documentation for the text of the most recent message, asks the
    generation chain for a structured solution and records it in the
    conversation.

    NOTE(review): on a retry the most recent message is the failure feedback,
    not the original question, so retrieval is driven by that text — confirm
    this is intended.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updated 'generation', 'messages' and 'iterations'
    """
    print("---GENERATING CODE SOLUTION---")
    # Work on a copy so the list held by the incoming state is never mutated
    # in place; the updated conversation is handed back via the return value.
    messages = list(state["messages"])
    iterations = state.get("iterations", 0)
    error = state.get("error", "")
    # Messages are (role, content) tuples; index 1 is the content.
    question = messages[-1][1]

    # We have been routed back to generation with an error
    if error == "yes":
        messages.append(
            (
                "user",
                "Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:",
            )
        )

    # Retrieval context, formatted with the module-level helper.
    retrieved_docs = retriever.invoke(question)
    formated_docs = format_docs(retrieved_docs)

    # Solution
    code_solution = code_gen_chain.invoke(
        {"context": formated_docs, "messages": messages}
    )
    messages.append(
        (
            "assistant",
            f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        )
    )

    return {
        "generation": code_solution,
        "messages": messages,
        "iterations": iterations + 1,
    }
def _check_outcome(code_solution, messages, iterations, error):
    """Assemble the state update returned by ``code_check``."""
    return {
        "generation": code_solution,
        "messages": messages,
        "iterations": iterations,
        "error": error,
    }


def code_check(state: GraphState):
    """
    Check code

    Runs the generated solution according to its declared language (python,
    javascript, typescript or rust). On failure, a "user" feedback message
    describing the error is appended to the conversation.

    SECURITY NOTE(review): this executes model-generated code on the host
    (``exec`` for Python, subprocesses otherwise) with no sandbox — confirm
    that is acceptable for the deployment.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): 'generation', 'messages', 'iterations' and 'error',
        where 'error' is "yes" when a check failed and "no" otherwise. It is
        also "no" when the code cannot be tested (unknown language or missing
        toolchain), since a retry would not help.
    """
    print("---CHECKING CODE---")
    # Copy so the list held by the incoming state is not mutated in place.
    messages = list(state["messages"])
    code_solution = state["generation"]
    iterations = state["iterations"]

    # Get solution components
    imports = code_solution.imports
    code = code_solution.code
    language = code_solution.language.lower()
    full_code = imports + "\n" + code

    def failed(feedback):
        messages.append(("user", feedback))
        return _check_outcome(code_solution, messages, iterations, "yes")

    def passed():
        return _check_outcome(code_solution, messages, iterations, "no")

    if language == "python":
        # Use a fresh, single namespace per exec: with the default split
        # globals/locals inside a function, functions defined by the snippet
        # cannot see the snippet's own imports, and the snippet could touch
        # this function's variables.
        try:
            exec(imports, {"__name__": "__code_check__"})
        except Exception as e:  # any failure of generated code is feedback
            print("---CODE IMPORT CHECK: FAILED---")
            return failed(f"Your solution failed the import test: {e}")
        try:
            exec(full_code, {"__name__": "__code_check__"})
        except Exception as e:  # any failure of generated code is feedback
            print("---CODE BLOCK CHECK: FAILED---")
            return failed(f"Your solution failed the code execution test: {e}")

    elif language in ("javascript", "typescript"):
        if language == "typescript":
            tag = "TS"
            feedback_prefix = "Your typesript solution failed the code execution test: "
        else:
            tag = "JS"
            feedback_prefix = "Your javascript solution failed the code execution test: "
        result = run_javascript_code(full_code, is_typescript=(language == "typescript"))
        if isinstance(result, str):
            # run_javascript_code returns a plain string when the toolchain is
            # missing or the launch failed; that is not a defect in the code.
            print("---CANNOT TEST CODE: NODE.JS ENVIRONMENT UNAVAILABLE---")
            print(result)
            return passed()
        if result.returncode != 0 or result.stderr:
            # Fall back to stdout when stderr is empty.
            details = result.stderr or result.stdout
            print(f"---{tag} CODE BLOCK CHECK: FAILED---")
            print(f"This is the error:{details}")
            return failed(f"{feedback_prefix}{details}")

    elif language == "rust":
        with open('code.rs', 'w') as file:
            file.write(full_code)
        try:
            try:
                compile_result = subprocess.run(
                    ['rustc', 'code.rs'], capture_output=True, text=True
                )
            except FileNotFoundError:
                print("---CANNOT TEST CODE: RUSTC NOT AVAILABLE---")
                return passed()
            # Decide on the exit status and captured text; the previous check
            # tested the pipe object itself, which is always truthy.
            if compile_result.returncode != 0:
                compile_errors = compile_result.stderr
                print("---RUST CODE BLOCK CHECK: COMPILATION FAILED---")
                print(f"This is the error:{compile_errors}")
                return failed(
                    f"Your rust solution failed the code compilation test: {compile_errors}"
                )
            run_result = subprocess.run(['./code'], capture_output=True, text=True)
            if run_result.returncode != 0 or run_result.stderr:
                run_errors = run_result.stderr or run_result.stdout
                print("---RUST CODE BLOCK CHECK: RUN FAILED---")
                print(f"This is the error:{run_errors}")
                return failed(f"Your rust solution failed the code run test: {run_errors}")
        finally:
            # Remove the source and binary so a stale ./code is never reused.
            for leftover in ("code.rs", "code"):
                if os.path.exists(leftover):
                    os.remove(leftover)

    else:
        # Can't test the code
        print("---CANNOT TEST CODE: CODE NOT IN EXPECTED LANGUAGE---")
        return passed()

    # No errors
    print("---NO CODE TEST FAILURES---")
    return passed()
def reflect(state: GraphState):
    """
    Reflect on errors

    Retrieves documentation for the text of the most recent message, runs the
    generation chain on the conversation and appends its output as an
    assistant "reflections" message. The stored solution is passed through
    unchanged.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): 'generation' (unchanged), updated 'messages' and
        'iterations' (unchanged)
    """
    print("---REFLECTING ON CODE SOLUTION ERRORS---")
    # Copy so the list held by the incoming state is not mutated in place.
    messages = list(state["messages"])
    iterations = state["iterations"]
    code_solution = state["generation"]
    # Messages are (role, content) tuples; index 1 is the content.
    question = messages[-1][1]

    # Retrieval context, formatted with the module-level helper.
    retrieved_docs = retriever.invoke(question)
    formated_docs = format_docs(retrieved_docs)

    # Add reflection
    reflections = code_gen_chain.invoke(
        {"context": formated_docs, "messages": messages}
    )
    messages.append(("assistant", f"Here are reflections on the error: {reflections}"))
    return {"generation": code_solution, "messages": messages, "iterations": iterations}
| ### Edges | |
def decide_to_finish(state: GraphState):
    """
    Determines whether to finish.

    Finishes when the last check reported no error or when the retry budget
    (``max_iterations``) is used up. The budget test is ``>=`` so a counter
    that starts at or above the limit still terminates instead of looping.

    Args:
        state (dict): The current graph state

    Returns:
        str: Next node to call: "end", "reflect" or "generate"
    """
    error = state["error"]
    iterations = state["iterations"]

    if error == "no" or iterations >= max_iterations:
        print("---DECISION: FINISH---")
        return "end"

    print("---DECISION: RE-TRY SOLUTION---")
    return "reflect" if flag == "reflect" else "generate"
def get_runnable():
    """Build and compile the code-assistant graph.

    Only the ``generate`` node is wired in (START -> generate -> END). The
    check/reflect loop and the checkpointer are left disabled in the comments
    below.

    Returns:
        The compiled graph, compiled without a checkpointer.
    """
    graph_builder = StateGraph(GraphState)

    # Nodes
    graph_builder.add_node("generate", generate)  # generation solution
    # graph_builder.add_node("check_code", code_check)  # check code
    # graph_builder.add_node("reflect", reflect)  # reflect

    # Edges
    graph_builder.add_edge(START, "generate")
    graph_builder.add_edge("generate", END)
    # graph_builder.add_edge("generate", "check_code")
    # graph_builder.add_conditional_edges(
    #     "check_code",
    #     decide_to_finish,
    #     {
    #         "end": END,
    #         "reflect": "reflect",
    #         "generate": "generate",
    #     },
    # )
    # graph_builder.add_edge("reflect", "generate")

    # Remove the checkpointer for now since it's causing issues
    # memory = AsyncSqliteSaver.from_conn_string(":memory:")
    # return graph_builder.compile(checkpointer=memory)
    return graph_builder.compile()
| # if __name__ == "__main__": | |
| # graph = get_runnable() | |
| # prompt = "How do I read from the solana network?" | |
| # print(f'{graph.invoke({"messages": [("user", prompt)], "iterations": 0, "error": ""})}') | |