|
|
import os |
|
|
from time import time |
|
|
import huggingface_hub |
|
|
import streamlit as st |
|
|
from config import config |
|
|
from functioncall import ModelInference |
|
|
|
|
|
|
|
|
@st.cache_resource(show_spinner="Loading model..")
def init_llm():
    """Authenticate with Hugging Face and build the model-inference engine.

    Cached via ``st.cache_resource`` so the model is loaded once per
    Streamlit server process rather than on every rerun.
    """
    huggingface_hub.login(token=config.hf_token, new_session=False)
    engine = ModelInference(chat_template=config.chat_template)
    return engine
|
|
|
|
|
|
|
|
def get_response(prompt):
    """Run the function-calling agent on *prompt*.

    Returns the agent's result, or an error string if inference fails
    (the caller renders either one directly, so no exception escapes).
    """
    try:
        result = llm.generate_function_call(
            prompt, config.chat_template, config.num_fewshot, config.max_depth
        )
    except Exception as e:
        return f"An error occurred: {str(e)}"
    return result
|
|
|
|
|
|
|
|
def get_output(context, user_input):
    """Compose the final natural-language answer from the agent's context.

    Builds a system prompt from ``prompt_assets/output_sys_prompt.yml`` plus
    the gathered *context*, then runs the chat model on the user's question.

    Args:
        context: Information gathered by the agent (output of get_response).
        user_input: The user's original question.

    Returns:
        The model's response string, or an error string if anything fails.
    """
    try:
        # Bug fix: config.status is only assigned inside main()'s Streamlit
        # st.status block. In headless runs (main_headless) it is unset, so the
        # unconditional config.status.update(...) raised AttributeError, which
        # the broad except turned into a permanent "An error occurred" answer.
        # Guard so headless mode works; the Streamlit path is unchanged.
        # NOTE(review): confirm config does not declare a default `status`.
        status = getattr(config, "status", None)
        if status is not None:
            status.update(label=":bulb: Preparing answer..")
        # Resolve the prompt file relative to this script, not the CWD.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        prompt_path = os.path.join(script_dir, 'prompt_assets', 'output_sys_prompt.yml')
        prompt_schema = llm.prompter.read_yaml_file(prompt_path)
        sys_prompt = (
            llm.prompter.format_yaml_prompt(prompt_schema, dict())
            + f"Information:\n{context}"
        )
        convo = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": user_input},
        ]
        response = llm.run_inference(convo)
        return response
    except Exception as e:
        return f"An error occurred: {str(e)}"
|
|
|
|
|
|
|
|
def main():
    """Render the Streamlit demo page and handle the Generate button."""
    st.title("LLM-ADE 9B Demo")

    input_text = st.text_area("Enter your text here:", value="", height=200)

    # Nothing to do until the user clicks Generate.
    if not st.button("Generate"):
        return

    if not input_text:
        st.warning("Please enter some text to generate a response.")
        return

    with st.status("Generating response...") as status:
        # Expose the status widget globally so get_output can report progress.
        config.status = status
        agent_resp = get_response(input_text)
        st.write(get_output(agent_resp, input_text))
        config.status.update(label="Finished!", state="complete", expanded=True)
|
|
|
|
|
|
|
|
# Module-level model handle shared by get_response/get_output; init_llm is
# wrapped in st.cache_resource, so the underlying model loads only once.
llm = init_llm()
|
|
|
|
|
|
|
|
def main_headless(prompt: str):
    """Run one agent round-trip on the CLI and print the timed, colored answer."""
    started_at = time()
    answer = get_output(get_response(prompt), prompt)
    # ANSI blue for the answer text, reset afterwards.
    print("\033[94m" + answer + "\033[0m")
    print(f"Time taken: {time() - started_at:.2f}s\n" + "-" * 20)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Headless mode exposes main_headless as a CLI via python-fire;
    # otherwise run the Streamlit UI entry point.
    if not config.headless:
        main()
    else:
        import fire

        fire.Fire(main_headless)
|
|
|