| | import os |
| | import huggingface_hub |
| | import streamlit as st |
| | from config import config |
| | from utils import get_assistant_message |
| | from functioncall import ModelInference |
| | from prompter import PromptManager |
| |
|
# Startup marker so container/CI logs show the module actually loaded;
# flush=True pushes it through stdout buffering immediately.
print("Why, hello there!", flush=True)
| |
|
@st.cache_resource(show_spinner="Loading model..")
def init_llm():
    """Authenticate against the Hugging Face Hub and build the model wrapper.

    Cached by Streamlit (`st.cache_resource`) so the model loads only once
    per process, not on every script rerun.
    """
    huggingface_hub.login(token=config.hf_token, new_session=False)
    return ModelInference(chat_template=config.chat_template)
| |
|
def get_response(prompt):
    """Run the function-calling agent on *prompt*.

    Returns the agent result, or an error string if inference raised —
    callers treat both uniformly as displayable text/context.
    """
    try:
        result = llm.generate_function_call(
            prompt, config.chat_template, config.num_fewshot, config.max_depth
        )
    except Exception as e:
        return f"An error occurred: {str(e)}"
    return result
| | |
def get_output(context, user_input):
    """Compose the final answer from the agent *context* and the user's question.

    Builds a system prompt from a YAML schema plus the gathered context, runs
    one inference turn, and extracts the assistant message. On any failure the
    error text is returned instead of raising.
    """
    try:
        config.status.update(label=":bulb: Preparing answer..")
        schema = llm.prompter.read_yaml_file("prompt_assets/output_sys_prompt.yml")
        base_prompt = llm.prompter.format_yaml_prompt(schema, dict())
        system_text = base_prompt + f"Information:\n{context}"
        messages = [
            {"role": "system", "content": system_text},
            {"role": "user", "content": user_input},
        ]
        raw = llm.run_inference(messages)
        return get_assistant_message(raw, config.chat_template, llm.tokenizer.eos_token)
    except Exception as e:
        return f"An error occurred: {str(e)}"
| |
|
def main():
    """Streamlit UI: a text box and a Generate button driving the agent."""
    st.title("LLM-ADE 9B Demo")

    input_text = st.text_area("Enter your text here:", value="", height=200)

    # Guard clauses: nothing to do until the button is pressed,
    # and an empty prompt only gets a warning.
    if not st.button("Generate"):
        return
    if not input_text:
        st.warning("Please enter some text to generate a response.")
        return

    with st.status('Generating response...') as status:
        # Expose the status widget globally so helpers can update its label.
        config.status = status
        agent_resp = get_response(input_text)
        st.write(get_output(agent_resp, input_text))
        config.status.update(label="Finished!", state="complete", expanded=True)
| |
|
# Module-level singleton used by get_response()/get_output(); created at import
# time so both the Streamlit UI and headless mode share one cached model.
llm = init_llm()
| |
|
def main_headless():
    """Interactive terminal loop: read prompts from stdin until EOF/Ctrl-C.

    Previously an unguarded ``input()`` meant Ctrl-D raised EOFError and
    Ctrl-C raised KeyboardInterrupt, killing the loop with a traceback;
    both now exit the loop cleanly.
    """
    while True:
        try:
            input_text = input("Enter your text here: ")
        except (EOFError, KeyboardInterrupt):
            # Graceful shutdown on Ctrl-D / Ctrl-C instead of a traceback.
            break
        agent_resp = get_response(input_text)
        # ANSI blue (\033[94m ... \033[0m) so the answer stands out in the terminal.
        print('\033[94m' + get_output(agent_resp, input_text) + '\033[0m')
| |
|
if __name__ == "__main__":
    # Sanity-check that env-var plumbing works WITHOUT leaking the secret
    # value into stdout/logs (the original printed the raw secret).
    print(f"Test env vars: {'<set>' if os.getenv('TEST_SECRET') is not None else None}")
    if config.headless:
        main_headless()
    else:
        main()
| |
|