import os
import huggingface_hub
import streamlit as st
from config import config
from utils import get_assistant_message
from functioncall import ModelInference
from prompter import PromptManager
|
|
|
|
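# Authenticate with the Hugging Face Hub and load the model once.
# st.cache_resource keeps a single model instance alive across Streamlit reruns
# instead of reloading it on every interaction.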
@st.cache_resource(show_spinner="Loading model...")
def init_llm():
    huggingface_hub.login(token=config.hf_token, new_session=False)
    llm = ModelInference(chat_template=config.chat_template)
    return llm
|
|
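# First inference pass: ask the model for a function-call / agentic response to
# the prompt. Exceptions are returned as strings so the caller can display them
# instead of crashing.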
def get_response(prompt):
    try:
        return llm.generate_function_call(
            prompt,
            config.chat_template,
            config.num_fewshot,
            config.max_depth
        )
    except Exception as e:
        return f"An error occurred: {str(e)}"

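# Second inference pass: wrap the agent's output as "Information" context in the
# output system prompt, then answer the user's original question with it.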
def get_output(context, user_input):
    try:
        prompt_schema = llm.prompter.read_yaml_file("prompt_assets/output_sys_prompt.yml")
        sys_prompt = llm.prompter.format_yaml_prompt(prompt_schema, dict()) + \
            f"Information:\n{context}"
        convo = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": user_input},
        ]
        response = llm.run_inference(convo)
        return get_assistant_message(response, config.chat_template, llm.tokenizer.eos_token)
    except Exception as e:
        return f"An error occurred: {str(e)}"
|
|
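# Streamlit UI: a single text box and a "Generate" button that runs the
# two-pass pipeline (get_response -> get_output) and renders the result.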
def main():
    st.title("LLM-ADE 9B Demo")

    input_text = st.text_area("Enter your text here:", value="", height=200)

    if st.button("Generate"):
        if input_text:
            with st.spinner('Generating response...'):
                agent_resp = get_response(input_text)
                st.write(get_output(agent_resp, input_text))
        else:
            st.warning("Please enter some text to generate a response.")
|
|
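# Module-level initialization: get_response and get_output close over this
# global, and it is created at import time so the Streamlit and headless paths
# share the same model instance.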
llm = init_llm()
|
|
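# Simple terminal loop for running the demo without the Streamlit UI
# (interrupt with Ctrl+C to exit); responses are printed in blue via ANSI codes.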
def main_headless():
    while True:
        input_text = input("Enter your text here: ")
        agent_resp = get_response(input_text)
        print('\033[94m' + get_output(agent_resp, input_text) + '\033[0m')
|
|
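# config.headless selects between the terminal loop and the Streamlit app.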
if __name__ == "__main__":
    if config.headless:
        main_headless()
    else:
        main()
|
|