odaly committed on
Commit
a8ff722
·
verified ·
1 Parent(s): 7155dd3

UPDATE APP.PY

Browse files
Files changed (1) hide show
  1. app.py +151 -118
app.py CHANGED
@@ -1,121 +1,154 @@
1
- import gradio as gr
2
- import numpy as np
3
- import random
4
- import spaces
5
- import torch
6
- from diffusers import DiffusionPipeline
7
-
8
- dtype = torch.bfloat16
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
-
11
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, revision="refs/pr/1").to(device)
12
-
13
- MAX_SEED = np.iinfo(np.int32).max
14
- MAX_IMAGE_SIZE = 2048
15
-
16
- @spaces.GPU()
17
- def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
18
- if randomize_seed:
19
- seed = random.randint(0, MAX_SEED)
20
- generator = torch.Generator().manual_seed(seed)
21
- image = pipe(
22
- prompt = prompt,
23
- width = width,
24
- height = height,
25
- num_inference_steps = num_inference_steps,
26
- generator = generator,
27
- guidance_scale=0.0
28
- ).images[0]
29
- return image, seed
30
-
31
- examples = [
32
- "a tiny astronaut hatching from an egg on the moon",
33
- "a cat holding a sign that says hello world",
34
- "an anime illustration of a wiener schnitzel",
35
- ]
36
-
37
- css="""
38
- #col-container {
39
- margin: 0 auto;
40
- max-width: 520px;
41
- }
42
- """
43
-
44
- with gr.Blocks(css=css) as demo:
45
-
46
- with gr.Column(elem_id="col-container"):
47
- gr.Markdown(f"""# fuzzy.1 [schnell]
48
- 12B param rectified flow transformer distilled from [fuzzy.1 [pro]](https://blackforestlabs.ai/) for 4 step generation
49
- [[blog](https://blackforestlabs.ai/2024/07/31/announcing-black-forest-labs/)] [[model](https://huggingface.co/spaces/odaly/fuzzylab)]
50
- """)
51
-
52
- with gr.Row():
53
- prompt = gr.Text(
54
- label="Prompt",
55
- show_label=False,
56
- max_lines=1,
57
- placeholder="Enter your prompt",
58
- container=False,
59
- )
60
-
61
- run_button = gr.Button("Run", scale=0)
62
-
63
- result = gr.Image(label="Result", show_label=False)
64
-
65
- with gr.Accordion("Advanced Settings", open=False):
66
-
67
- seed = gr.Slider(
68
- label="Seed",
69
- minimum=0,
70
- maximum=MAX_SEED,
71
- step=1,
72
- value=0,
73
- )
74
-
75
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
76
-
77
- with gr.Row():
78
-
79
- width = gr.Slider(
80
- label="Width",
81
- minimum=256,
82
- maximum=MAX_IMAGE_SIZE,
83
- step=32,
84
- value=1024,
85
- )
86
-
87
- height = gr.Slider(
88
- label="Height",
89
- minimum=256,
90
- maximum=MAX_IMAGE_SIZE,
91
- step=32,
92
- value=1024,
93
- )
94
-
95
- with gr.Row():
96
-
97
-
98
- num_inference_steps = gr.Slider(
99
- label="Number of inference steps",
100
- minimum=1,
101
- maximum=50,
102
- step=1,
103
- value=4,
104
- )
105
-
106
- gr.Examples(
107
- examples = examples,
108
- fn = infer,
109
- inputs = [prompt],
110
- outputs = [result, seed],
111
- cache_examples="lazy"
112
- )
113
 
114
- gr.on(
115
- triggers=[run_button.click, prompt.submit],
116
- fn = infer,
117
- inputs = [prompt, seed, randomize_seed, width, height, num_inference_steps],
118
- outputs = [result, seed]
119
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- demo.launch()
 
 
 
1
+ # app.py
2
+ from typing import List, Union
3
+
4
+ from dotenv import load_dotenv, find_dotenv
5
+ from langchain.callbacks import get_openai_callback
6
+ from langchain.chat_models import ChatOpenAI
7
+ from langchain.schema import (SystemMessage, HumanMessage, AIMessage)
8
+ from langchain.llms import LlamaCpp
9
+ from langchain.callbacks.manager import CallbackManager
10
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
11
+ import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+
14
def init_page() -> None:
    """Configure the Streamlit page and render the header and sidebar title."""
    st.set_page_config(page_title="Personal ChatGPT")
    st.header("Personal ChatGPT")
    st.sidebar.title("Options")
20
+
21
+
22
def init_messages() -> None:
    """Reset the conversation state.

    Reinitializes ``st.session_state.messages`` with a single system prompt
    and clears the per-request cost history.  Runs on first load (no
    "messages" key yet) or when the sidebar "Clear Conversation" button is
    pressed.
    """
    clear_button = st.sidebar.button("Clear Conversation", key="clear")
    if clear_button or "messages" not in st.session_state:
        st.session_state.messages = [
            SystemMessage(
                # Bug fix: the prompt misspelled "mardkown"; corrected to
                # "markdown" so the model receives a coherent instruction.
                content="You are a helpful AI assistant. Reply your answer in markdown format.")
        ]
        st.session_state.costs = []
30
+
31
+
32
def select_llm() -> Union[ChatOpenAI, LlamaCpp]:
    """Build the LLM chosen in the sidebar (model radio + temperature slider)."""
    chosen_model = st.sidebar.radio("Choose LLM:",
                                    ("gpt-3.5-turbo-0613", "gpt-4",
                                     "llama-2-7b-chat.ggmlv3.q2_K"))
    sampling_temperature = st.sidebar.slider("Temperature:", min_value=0.0,
                                             max_value=1.0, value=0.0, step=0.01)
    if chosen_model.startswith("gpt-"):
        return ChatOpenAI(temperature=sampling_temperature, model_name=chosen_model)
    if chosen_model.startswith("llama-2-"):
        # Stream generated tokens to stdout as they are produced.
        stdout_streaming = CallbackManager([StreamingStdOutCallbackHandler()])
        return LlamaCpp(
            model_path=f"./models/{chosen_model}.bin",
            input={"temperature": sampling_temperature,
                   "max_length": 2000,
                   "top_p": 1},
            callback_manager=stdout_streaming,
            verbose=False,
        )
51
+
52
+
53
def get_answer(llm, messages) -> tuple[str, float]:
    """Run the chat history through *llm* and return ``(answer_text, cost_usd)``.

    OpenAI models are invoked inside LangChain's OpenAI callback so the dollar
    cost of the call can be reported; local LlamaCpp models are billed as 0.0.
    NOTE(review): if *llm* is neither type, execution falls through and the
    function implicitly returns None — confirm callers never hit this.
    """
    if isinstance(llm, ChatOpenAI):
        # The callback context records token usage/cost for this single call.
        with get_openai_callback() as cb:
            answer = llm(messages)
        return answer.content, cb.total_cost
    if isinstance(llm, LlamaCpp):
        # Llama 2 expects one specially formatted prompt string, not message objects.
        return llm(llama_v2_prompt(convert_langchainschema_to_dict(messages))), 0.0
60
+
61
+
62
def find_role(message: Union[SystemMessage, HumanMessage, AIMessage]) -> str:
    """Map a langchain.schema message object to its chat role name."""
    # Ordered dispatch table; check order matches the original isinstance chain.
    for message_type, role_name in ((SystemMessage, "system"),
                                    (HumanMessage, "user"),
                                    (AIMessage, "assistant")):
        if isinstance(message, message_type):
            return role_name
    raise TypeError("Unknown message type.")
73
+
74
+
75
def convert_langchainschema_to_dict(
        messages: List[Union[SystemMessage, HumanMessage, AIMessage]]) \
        -> List[dict]:
    """Turn a list of langchain.schema messages into plain role/content dicts."""
    converted: List[dict] = []
    for msg in messages:
        converted.append({"role": find_role(msg), "content": msg.content})
    return converted
85
+
86
+
87
def llama_v2_prompt(messages: List[dict]) -> str:
    """Render role/content dicts as a Llama 2 chat prompt string.

    The system prompt is folded into the first user turn between <<SYS>> tags;
    each completed (user, assistant) pair becomes a closed
    ``<s>[INST] ... [/INST] ... </s>`` block, and the trailing user message is
    left open for the model to answer.
    """
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
    BOS, EOS = "<s>", "</s>"
    DEFAULT_SYSTEM_PROMPT = f"""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""

    # Guarantee a system message at position 0.
    if messages[0]["role"] != "system":
        messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}] + messages
    # Fold the system prompt into the first real turn.
    messages = [{
        "role": messages[1]["role"],
        "content": B_SYS + messages[0]["content"] + E_SYS + messages[1]["content"],
    }] + messages[2:]

    pieces = []
    # Completed (prompt, answer) pairs become closed instruction blocks.
    for user_turn, model_turn in zip(messages[::2], messages[1::2]):
        pieces.append(
            f"{BOS}{B_INST} {(user_turn['content']).strip()} {E_INST} {(model_turn['content']).strip()} {EOS}")
    # The last (unanswered) user message stays open for the model to complete.
    pieces.append(f"{BOS}{B_INST} {(messages[-1]['content']).strip()} {E_INST}")
    return "".join(pieces)
118
+
119
+
120
def main() -> None:
    """Streamlit entry point: wire up the page, model selection and chat loop."""
    _ = load_dotenv(find_dotenv())

    init_page()
    llm = select_llm()
    init_messages()

    # Handle a new question from the chat box, if one was submitted.
    if user_input := st.chat_input("Input your question!"):
        st.session_state.messages.append(HumanMessage(content=user_input))
        with st.spinner("ChatGPT is typing ..."):
            answer, cost = get_answer(llm, st.session_state.messages)
        st.session_state.messages.append(AIMessage(content=answer))
        st.session_state.costs.append(cost)

    # Re-render the whole conversation on every Streamlit rerun.
    for message in st.session_state.get("messages", []):
        if isinstance(message, AIMessage):
            with st.chat_message("assistant"):
                st.markdown(message.content)
        elif isinstance(message, HumanMessage):
            with st.chat_message("user"):
                st.markdown(message.content)

    # Running total plus a per-call breakdown in the sidebar.
    costs = st.session_state.get("costs", [])
    st.sidebar.markdown("## Costs")
    st.sidebar.markdown(f"**Total cost: ${sum(costs):.5f}**")
    for cost in costs:
        st.sidebar.markdown(f"- ${cost:.5f}")
150
+
151
 
152
# streamlit run app.py
if __name__ == "__main__":
    # Bug fix: the original line was the bare expression `main`, which
    # evaluates the function object without calling it, so the app never ran.
    main()