import os
from time import time
import huggingface_hub
import streamlit as st
from config import config
from functioncall import ModelInference
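# Assumptions about the local modules, inferred from how they are used below:
# `config` is expected to provide hf_token, chat_template, num_fewshot, max_depth,
# headless, and a writable `status` slot used for progress updates; `ModelInference`
# wraps the model and exposes a `prompter` helper plus `run_inference`.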
|
@st.cache_resource(show_spinner="Loading model..")
def init_llm():
    """Log in to the Hugging Face Hub and load the model once per session."""
    huggingface_hub.login(token=config.hf_token, new_session=False)
    llm = ModelInference(chat_template=config.chat_template)
    return llm
|
def function_agent(prompt):
    """Run the function-calling loop to gather context for the prompt."""
    try:
        return llm.generate_function_call(
            prompt, config.chat_template, config.num_fewshot, config.max_depth
        )
    except Exception as e:
        return f"An error occurred: {str(e)}"
|
def output_agent(context, user_input):
    """Take the context gathered by the function agent and generate the final response."""
    try:
        config.status.update(label=":bulb: Preparing answer..")
        script_dir = os.path.dirname(os.path.abspath(__file__))
        prompt_path = os.path.join(script_dir, "prompt_assets", "output_sys_prompt.yml")
        prompt_schema = llm.prompter.read_yaml_file(prompt_path)
        sys_prompt = (
            llm.prompter.format_yaml_prompt(prompt_schema, dict())
            + f"Information:\n{context}"
        )
        convo = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": user_input},
        ]
        response = llm.run_inference(convo)
        return response
    except Exception as e:
        return f"An error occurred: {str(e)}"
def query_agent(prompt):
    """Wrap the prompt with the system prompt and run inference on it directly."""
    try:
        config.status.update(label=":brain: Starting inference..")
        script_dir = os.path.dirname(os.path.abspath(__file__))
        prompt_path = os.path.join(script_dir, "prompt_assets", "output_sys_prompt.yml")
        prompt_schema = llm.prompter.read_yaml_file(prompt_path)
        sys_prompt = llm.prompter.format_yaml_prompt(prompt_schema, dict())
        convo = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": prompt},
        ]
        response = llm.run_inference(convo)
        return response
    except Exception as e:
        return f"An error occurred: {str(e)}"
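# Note: query_agent is not called by get_response below; it answers the prompt
# directly, without the function-calling step.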
|
def get_response(input_text: str):
    """Generate the final response: gather context via function calls, then compose the answer."""
    agent_resp = function_agent(input_text)
    output = output_agent(agent_resp, input_text)
    return output
|
def main():
    st.title("LLM-ADE 9B Demo")

    input_text = st.text_area("Enter your text here:", value="", height=200)

    if st.button("Generate"):
        if input_text:
            with st.status("Generating response...") as status:
                config.status = status
                st.write(get_response(input_text))
                config.status.update(label="Finished!", state="complete", expanded=True)
        else:
            st.warning("Please enter some text to generate a response.")
|
def main_headless(prompt: str):
    """Answer a single prompt from the command line and print the timed response."""
    start = time()
    print("\033[94m" + get_response(prompt) + "\033[0m")
    print(f"Time taken: {time() - start:.2f}s\n" + "-" * 20)
|
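# Load the model once at module level; st.cache_resource keeps it cached across
# Streamlit reruns, and the headless path below reuses the same instance.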
llm = init_llm()
|
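# Entry point: with config.headless set, python-fire passes the command-line argument
# to main_headless (e.g. `python app.py "your prompt"`, assuming this file is app.py);
# otherwise launch the UI with `streamlit run app.py`.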
if __name__ == "__main__":
    if config.headless:
        import fire

        fire.Fire(main_headless)
    else:
        main()