Hugging Face Spaces capture — the Space was showing a "Runtime error" status when this source was copied.
| import openai | |
| import streamlit as st | |
| from streamlit_chat import message | |
| from langchain_core.messages import SystemMessage | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.pydantic_v1 import BaseModel, Field | |
| from langgraph.graph import MessageGraph, END | |
| from langgraph.checkpoint.sqlite import SqliteSaver | |
| from langchain_core.messages import HumanMessage | |
| from typing import List | |
| import os | |
| import uuid | |
# System prompt for the "info" node: drives the interview that collects the
# user's job, company, and tooling before the model calls the extraction tool.
# Fixed prompt typos from the original ("discerne" -> "discern", "e.g" -> "e.g.").
template = """Your job is to get information from a user about their profession. We are aiming to generate a profile later
You should get the following information from them:
- Job
- Company
- tools for example for a software engineer(which frameworks/languages)
If you are not able to discern this info, ask them to clarify! Do not attempt to wildly guess.
If you're asking anything please be friendly and comment on any of the info you have found e.g. working at x company must have been a thrilling challenge
Ask one question at a time
After you are able to discern all the information, call the relevant tool"""
# SECURITY: the original code hard-coded an OpenAI API key in source. Never
# commit secrets; read the key from the environment (set by the deployment
# platform, e.g. a Spaces secret) instead.
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY', '')
os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
# Deterministic chat model (temperature=0) shared by both graph nodes.
llm = ChatOpenAI(temperature=0)
def get_messages_info(messages):
    """Prepend the info-gathering system prompt to the running transcript."""
    return [SystemMessage(content=template), *messages]
class PromptInstructions(BaseModel):
    # NOTE: this docstring doubles as the tool description sent to the model,
    # so it is kept verbatim.
    """Instructions on how to prompt the LLM."""
    job: str
    company: str
    technologies: List[str]
    # Fixed typo: was "hobies". Safe to rename because the tool-call arguments
    # are only consumed as a raw JSON string downstream, never via attributes.
    hobbies: List[str]
# Expose the extraction schema to the model as a callable tool.
llm_with_tool = llm.bind_tools([PromptInstructions])
# "info" node: system prompt + transcript piped into the tool-enabled model.
chain = get_messages_info | llm_with_tool
| # Helper function for determining if tool was called | |
| def _is_tool_call(msg): | |
| return hasattr(msg, "additional_kwargs") and 'tool_calls' in msg.additional_kwargs | |
# New system prompt
# {reqs} is filled with the raw JSON arguments of the extraction tool call.
prompt_system = """Based on the following context, write a good professional profile. Infer the soft skills:
{reqs}"""
# Function to get the messages for the profile
# Will only get messages AFTER the tool call
def get_profile_messages(messages):
    """Build the profile-generation prompt from the conversation.

    Extracts the raw JSON arguments of the (last seen) tool call and keeps
    only the messages that followed a tool call; everything earlier was just
    the interview and is dropped.
    """
    args = None
    after_call = []
    for msg in messages:
        if _is_tool_call(msg):
            args = msg.additional_kwargs['tool_calls'][0]['function']['arguments']
        elif args is not None:
            after_call.append(msg)
    header = SystemMessage(content=prompt_system.format(reqs=args))
    return [header, *after_call]
| profile_gen_chain = get_profile_messages | llm | |
def get_state(messages):
    """Route the graph: 'profile' once a tool call exists, END after an AI
    reply, otherwise stay on 'info'."""
    last = messages[-1]
    if _is_tool_call(last):
        return "profile"
    if not isinstance(last, HumanMessage):
        return END
    if any(_is_tool_call(m) for m in messages):
        return "profile"
    return "info"
def get_graph():
    """Compile the two-node conversation graph.

    "info" interviews the user until the model emits a tool call, then
    "profile" writes the final profile; get_state routes between them.
    """
    # In-memory checkpointer: conversation state lives only for this process.
    memory = SqliteSaver.from_conn_string(":memory:")
    # Conditional-edge map: get_state returns one of these keys and the graph
    # routes to the node of the same name (identity mapping).
    nodes = {k:k for k in ['info', 'profile', END]}
    workflow = MessageGraph()
    workflow.add_node("info", chain)
    workflow.add_node("profile", profile_gen_chain)
    workflow.add_conditional_edges("info", get_state, nodes)
    workflow.add_conditional_edges("profile", get_state, nodes)
    workflow.set_entry_point("info")
    graph = workflow.compile(checkpointer=memory)
    return graph
# Streamlit re-executes this whole script on every widget interaction, so the
# original code rebuilt the graph AND minted a fresh thread_id per rerun —
# throwing away the checkpointer's conversation memory each turn. Cache both
# in session_state so one browser session keeps one conversation thread.
if 'graph' not in st.session_state:
    st.session_state['graph'] = get_graph()
if 'config' not in st.session_state:
    st.session_state['config'] = {"configurable": {"thread_id": str(uuid.uuid4())}}
graph = st.session_state['graph']
config = st.session_state['config']
# Streamlit app layout
st.title("JobEasy AI")
clear_button = st.sidebar.button("Clear Conversation", key="clear")


def _initial_chat_state():
    # Built fresh on every call so the mutable lists are never shared.
    return {
        'generated': ['Please tell me about your most recent career'],
        'past': [],
        'messages': [],
    }


# Initialise any missing session keys; reset everything when the sidebar
# button was pressed.
for _key, _value in _initial_chat_state().items():
    if _key not in st.session_state:
        st.session_state[_key] = _value
if clear_button:
    st.session_state.update(_initial_chat_state())

# container for chat history (created first, so it renders above the input)
response_container = st.container()
# container for text box
container = st.container()
def query(payload):
    """Send one user message through the graph and record the replies.

    Streams node outputs from the compiled graph; each non-terminal output is
    an assistant message appended to the chat history alongside the user turn.
    """
    for output in graph.stream([HumanMessage(content=payload)], config=config):
        if "__end__" in output:
            continue
        # stream() yields dictionaries with output keyed by node name
        for key, value in output.items():
            st.session_state['messages'].append({"role": "assistant", "content": value.content})
            # Bug fix: the original appended the global `user_input` here,
            # ignoring the `payload` parameter and coupling this function to
            # the UI widget state.
            st.session_state['past'].append(payload)
            st.session_state['generated'].append(value.content)
with container:
    # Input form; clear_on_submit empties the text area after each send.
    with st.form(key='my_form', clear_on_submit=True):
        user_input = st.text_area("You:", key='input', height=100)
        sent = st.form_submit_button(label='Send')
    # Only run the graph when the form was submitted with non-empty text.
    if sent and user_input:
        query(user_input)
# Render the conversation: each assistant reply followed by the user message
# at the same index (the greeting at index 0 has no preceding user turn).
if st.session_state['generated']:
    with response_container:
        past_msgs = st.session_state["past"]
        for idx, reply in enumerate(st.session_state["generated"]):
            message(reply, key=str(idx))
            if past_msgs and idx < len(past_msgs):
                message(past_msgs[idx], is_user=True, key=str(idx) + '_user')