genaitiwari committed on
Commit
6617b51
·
1 Parent(s): 9f396ec

image generation

Browse files
.gitignore CHANGED
@@ -1,3 +1,4 @@
1
 
2
  *.pyc
3
  /autogen_cache/42
 
 
1
 
2
  *.pyc
3
  /autogen_cache/42
4
+ *.jpeg
app.py CHANGED
@@ -3,6 +3,7 @@ from configfile import Config
3
  from src.hf_autogen.hfautogen import hf_llmconfig
4
  from src.streamlitui.loadui import LoadStreamlitUI
5
  from src.usecases.textgen import TexGeneration
 
6
 
7
 
8
 
@@ -37,16 +38,26 @@ if __name__ == "__main__":
37
  st.write(problem)
38
 
39
 
40
- obj_txtgen = TexGeneration(assistant_name="Assistant", user_proxy_name='Userproxy',
41
  llm_config=llm_config,
42
  problem=problem)
43
- obj_txtgen.run()
44
 
45
  elif user_input['selected_usecase'] == "Image Generation":
46
  st.subheader("Image generation")
 
47
  if problem:
48
  with st.chat_message("user"):
49
  st.write(problem)
 
 
 
 
 
 
 
 
 
50
 
51
 
52
 
 
3
  from src.hf_autogen.hfautogen import hf_llmconfig
4
  from src.streamlitui.loadui import LoadStreamlitUI
5
  from src.usecases.textgen import TexGeneration
6
+ from src.usecases.imggen import ImageGeneration
7
 
8
 
9
 
 
38
  st.write(problem)
39
 
40
 
41
+ obj_txt_gen = TexGeneration(assistant_name="Assistant", user_proxy_name='Userproxy',
42
  llm_config=llm_config,
43
  problem=problem)
44
+ obj_txt_gen.run()
45
 
46
  elif user_input['selected_usecase'] == "Image Generation":
47
  st.subheader("Image generation")
48
+
49
  if problem:
50
  with st.chat_message("user"):
51
  st.write(problem)
52
+
53
+
54
+ obj_img_gen = ImageGeneration(assistant_name="Image_Assistant", user_proxy_name='Userproxy',
55
+ llm_config=llm_config,
56
+ problem=problem)
57
+ obj_img_gen.run()
58
+
59
+ # with st.chat_message('ai'):
60
+ # st.image(image.open('./imagegen/response.jpeg'))
61
 
62
 
63
 
src/agents/assistantagent.py CHANGED
@@ -1,5 +1,7 @@
1
  from autogen import AssistantAgent
2
  import streamlit as st
 
 
3
 
4
 
5
  class TrackableAssistantAgent(AssistantAgent):
@@ -7,8 +9,13 @@ class TrackableAssistantAgent(AssistantAgent):
7
  if message and type(message)== str and sender.name =="Userproxy":
8
  with st.chat_message("user"):
9
  st.write(message)
10
-
11
-
12
  return super()._process_received_message(message, sender, silent)
13
 
 
 
 
 
 
 
 
14
 
 
1
  from autogen import AssistantAgent
2
  import streamlit as st
3
+ import base64
4
+ from io import BytesIO
5
 
6
 
7
  class TrackableAssistantAgent(AssistantAgent):
 
9
  if message and type(message)== str and sender.name =="Userproxy":
10
  with st.chat_message("user"):
11
  st.write(message)
 
 
12
  return super()._process_received_message(message, sender, silent)
13
 
14
class TrackableImageAssistantAgent(AssistantAgent):
    """Assistant agent used for the image-generation flow.

    Rendering of the generated image happens elsewhere (the user proxy /
    model client), so this subclass simply defers to the base handler.
    """

    def _process_received_message(self, message, sender, silent):
        # No UI work here by design — just forward to AssistantAgent.
        return super()._process_received_message(message, sender, silent)
19
+
20
+
21
 
src/agents/userproxyagent.py CHANGED
@@ -1,12 +1,17 @@
1
  from autogen import UserProxyAgent
2
  import streamlit as st
 
 
3
 
4
 
5
  class TrackableUserProxyAgent(UserProxyAgent):
6
  def _process_received_message(self, message, sender, silent):
7
- with st.chat_message(sender.name.lower()):
8
- if type(message)==str:
 
 
 
 
 
9
  st.write(message)
10
- else :
11
- st.write(message['content'])
12
  return super()._process_received_message(message, sender, silent)
 
1
  from autogen import UserProxyAgent
2
  import streamlit as st
3
+ import base64
4
+ from io import BytesIO
5
 
6
 
7
class TrackableUserProxyAgent(UserProxyAgent):
    """User-proxy agent that mirrors every received message into the
    Streamlit chat UI before delegating to the base implementation."""

    def _process_received_message(self, message, sender, silent):
        if isinstance(message, str) and sender.name == 'Image_Assistant':
            # The image assistant's textual reply is raw image bytes saved by
            # the model client — show the generated picture instead of text.
            with st.chat_message('ai'):
                st.image('./imagegen/response.jpeg')
        else:
            with st.chat_message('ai'):
                # BUG FIX: dict-shaped autogen messages were rendered raw;
                # restore the previous behavior of showing their 'content'.
                if isinstance(message, dict):
                    st.write(message.get('content', message))
                else:
                    st.write(message)
        return super()._process_received_message(message, sender, silent)
src/hf_autogen/hfautogen.py CHANGED
@@ -49,7 +49,7 @@ class APIModelClient:
49
 
50
  input_data = {
51
  "inputs": conversation_history,
52
- "parameters": {"max_new_tokens": 1000, "return_full_text": False, "do_sample": False},
53
  "options": {"wait_for_model": True, "use_cache": False}
54
  # Include any other parameters required by your API
55
  }
 
49
 
50
  input_data = {
51
  "inputs": conversation_history,
52
+ "parameters": {"return_full_text": False, "do_sample": False},
53
  "options": {"wait_for_model": True, "use_cache": False}
54
  # Include any other parameters required by your API
55
  }
src/hf_autogen/imghfautogen.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import autogen
2
+ from autogen import AssistantAgent, UserProxyAgent, GroupChatManager, GroupChat, ConversableAgent
3
+ from types import SimpleNamespace
4
+ import requests
5
+ import json
6
+ import os
7
+ import shutil
8
+ import random
9
+ import streamlit as st
10
+ import base64
11
+ from io import BytesIO
12
+
13
+ from src.agents.assistantagent import TrackableImageAssistantAgent
14
+ from src.agents.userproxyagent import TrackableUserProxyAgent
15
+
16
class APIModelClient:
    """Custom autogen model client that calls the Hugging Face Inference API
    to generate an image (Stable Diffusion 3.5) from the flattened chat.

    Implements the autogen ModelClient protocol:
    ``create`` / ``message_retrieval`` / ``cost`` / ``get_usage``.
    """

    def __init__(self, config, **kwargs):
        # `config` is one entry of llm_config["config_list"].
        self.device = config.get("device", "cpu")
        self.api_url = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-3.5-large"
        # SECURITY FIX: the original commit hard-coded (and thereby leaked) a
        # Hugging Face bearer token in source control. Read the token from
        # the environment instead; the leaked token must be revoked.
        self.headers = {"Authorization": f"Bearer {os.environ.get('HF_API_KEY', '')}"}
        self.model_name = config.get("model")
        self.chat_index = 0
        self.conversion_mem = ""

    def create(self, params):
        """Flatten ``params['messages']`` into one prompt and POST it to the
        inference endpoint.

        Returns:
            requests.Response: the raw response (body is the image bytes).

        Raises:
            Exception: when the API answers with a non-200 status.
        """
        conversation_history = ""
        for message in params["messages"]:
            if message["role"] == "system":
                prefix = 'Bot Description:\n'
            elif message["role"] == "user":
                prefix = 'User____:\n'
            else:
                prefix = f'Agent ({message["role"]}):\n'
            conversation_history += prefix + f'{message["content"]}\n\n'

        input_data = {
            "inputs": conversation_history,
        }

        response = requests.post(self.api_url, json=input_data, headers=self.headers)
        if response.status_code == 200:
            return response
        raise Exception(f"API request failed with status code {response.status_code}: {response.text}")

    def message_retrieval(self, response):
        """Save the returned image to disk, render it in Streamlit, and hand
        the raw bytes back to autogen as the message list."""
        import io
        from PIL import Image

        image = Image.open(io.BytesIO(response.content))
        # ROBUSTNESS FIX: the output directory may not exist on first run.
        os.makedirs('./imagegen', exist_ok=True)
        image_path = './imagegen/response.jpeg'
        image.save(image_path)
        # Re-open from disk so Streamlit renders exactly what was persisted.
        image = Image.open(image_path)
        st.image(image, caption="Loaded Image", use_column_width=True)
        return [str(response.content)]

    def cost(self, response) -> float:
        """This inference tier is free; record and return zero cost."""
        response.cost = 0
        return 0

    @staticmethod
    def get_usage(response):
        # Returns a dict of prompt_tokens, completion_tokens, total_tokens,
        # cost, model if usage needs to be tracked; empty = no tracking.
        return {}
92
+
93
+
94
class APIModelClientWithArguments(APIModelClient):
    """APIModelClient variant whose API key and endpoint URL are supplied by
    the caller rather than baked into the class."""

    def __init__(self, config, hf_key,
                 hf_url="https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-3.5-large",
                 **kwargs):
        # Intentionally does not call super().__init__: the base constructor
        # hard-wires its own URL and credentials.
        self.device = config.get("device", "cpu")
        self.api_url = hf_url
        self.headers = {"Authorization": f"Bearer {hf_key}"}
        self.model_name = config.get("model")
        self.chat_index = 0
        self.conversion_mem = ""
108
def hf_llmconfig(selected_model):
    """Build the autogen llm_config for ``selected_model``, cache it in the
    Streamlit session state, and return it."""
    entry = {
        "model": selected_model,
        "model_client_cls": "APIModelClientWithArguments",
        "device": "",
    }
    llm_config = {"config_list": [entry]}
    st.session_state['llm_config'] = llm_config
    return llm_config
118
def UserAgent(name, llm_config, max_consecutive_auto_reply=1, code_dir="coding",
              use_docker=False, system_message="You are a helpful AI assistant"):
    """Create the user-proxy side of the image-generation chat.

    NOTE(review): the incoming ``llm_config`` argument is discarded and
    replaced by a fixed Mixtral text-model config — confirm this is intended.
    """
    llm_config = {
        "config_list": [{
            "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "model_client_cls": "APIModelClientWithArguments",
            "device": "",
        }]
    }
    user_agent = TrackableUserProxyAgent(
        name=name,
        max_consecutive_auto_reply=max_consecutive_auto_reply,
        llm_config=llm_config,
        # Stop as soon as the assistant signals completion.
        is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
        code_execution_config={
            "work_dir": code_dir,
            "use_docker": use_docker,
        },
        system_message=system_message,
        human_input_mode="NEVER",
    )
    # Bind the custom HF client with the API key stored by the Streamlit UI.
    user_agent.register_model_client(
        model_client_cls=APIModelClientWithArguments,
        hf_key=st.session_state["api_key"],
    )
    return user_agent
142
+
143
def ModelAgent(name, llm_config,
               hf_url="https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-3.5-large",
               system_message="", code_execution=False):
    """Create the image-generating assistant agent wired to the HF endpoint.

    An empty ``system_message`` falls back to a generic image-assistant prompt.
    """
    default_system_message = """You are a helpful AI assistant for generating and manipulating images.
    """

    if system_message == "":
        system_message = default_system_message

    agent = TrackableImageAssistantAgent(
        name=name,
        llm_config=llm_config,
        system_message=system_message,
        code_execution_config=code_execution,
    )
    # Attach the parameterized HF client (key from the Streamlit session).
    agent.register_model_client(
        model_client_cls=APIModelClientWithArguments,
        hf_key=st.session_state["api_key"],
        hf_url=hf_url,
    )
    return agent
171
+
172
+
173
async def InitChat(user, agent, _input, summary_method="reflection_with_llm"):
    """Run a single-turn chat between ``user`` and ``agent`` with a
    fixed-seed autogen cache rooted at ./autogen_cache."""

    def clear_directory_contents(dir_path):
        # Currently unused (its call sites were disabled); kept around as a
        # cache-busting helper for development.
        try:
            for item in os.listdir(dir_path):
                item_path = os.path.join(dir_path, item)
                if os.path.isfile(item_path) or os.path.islink(item_path):
                    os.remove(item_path)  # remove files and symlinks
                elif os.path.isdir(item_path):
                    shutil.rmtree(item_path)  # remove subdirectories
            shutil.rmtree(dir_path)
            print(f"All contents of '{dir_path}' have been removed.")
        except FileNotFoundError:
            pass

    seed = 42  # fixed seed -> deterministic cache directory
    custom_cache = autogen.Cache({"cache_seed": seed, "cache_path_root": "autogen_cache"})

    await user.a_initiate_chat(
        agent,
        max_turns=1,
        message=_input,
        summary_method=summary_method,
        cache=custom_cache,
    )
204
def GroupChat(user, agents, _input, hf_key,
              hf_url="https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-3.5-large",
              max_round=5):
    """Run a round-robin group chat among ``agents`` under a
    GroupChatManager backed by the custom HF model client.

    BUG FIX: the original called the async ``InitChat(...)`` without
    awaiting it, so the coroutine was created but never executed (CPython
    warns "coroutine was never awaited"). Since this is a synchronous entry
    point, drive the coroutine to completion on a dedicated event loop.
    """
    import asyncio  # local import: this module does not import asyncio at top level

    llm_config = {
        "config_list": [{
            "model": "",
            "model_client_cls": "APIModelClientWithArguments",
            "device": "",
        }]
    }

    groupchat = autogen.GroupChat(
        agents=agents,
        messages=[],
        max_round=max_round,
        speaker_selection_method="round_robin",
        allow_repeat_speaker=False,
    )
    manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)
    manager.register_model_client(
        model_client_cls=APIModelClientWithArguments, hf_key=hf_key, hf_url=hf_url,
    )

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(InitChat(user, manager, _input))
    finally:
        loop.close()
218
+
219
+ #Write me a script to save the BTC chart from the past year to an image.
220
+
221
+ # if __name__ == "__main__":
222
+ # print("Running as main")
223
+
224
+
225
+
226
+
src/usecases/imggen.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from src.hf_autogen.imghfautogen import APIModelClientWithArguments,ModelAgent, UserAgent, InitChat
3
+ from src.agents.assistantagent import TrackableAssistantAgent
4
+ from src.agents.userproxyagent import TrackableUserProxyAgent
5
+ import streamlit as st
6
+
7
+
8
class ImageGeneration:
    """Wires a user proxy to the HF image-generation assistant and runs a
    single-turn chat for ``problem``."""

    def __init__(self, assistant_name, user_proxy_name, llm_config, problem):
        self.user = UserAgent(name=user_proxy_name, llm_config=llm_config)
        self.assistant = ModelAgent(
            name=assistant_name,
            llm_config=llm_config,
            hf_url="https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-3.5-large",
            system_message="You are a friendly AI assistant. Your job is to generate image with HD quality",
        )
        self.problem = problem
        # FIX: the original created (and globally installed) an event loop
        # per instance in __init__ and never closed it — leaking a loop for
        # every ImageGeneration object. The loop is now created in run()
        # and closed when the chat finishes.

    def run(self):
        """Execute the chat synchronously on a dedicated event loop."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(InitChat(self.user, self.assistant, self.problem))
        finally:
            loop.close()
src/usecases/imggene.py DELETED
File without changes