Update app.py
app.py CHANGED
@@ -14,24 +14,24 @@ import torch
 import cv2
 from gradio_client import Client, file
 
+# Function to generate an image using another model
 def image_gen(prompt):
     client = Client("KingNish/Image-Gen-Pro")
-    return client.predict("Image Generation",None, prompt, api_name="/image_gen_pro")
+    return client.predict("Image Generation", None, prompt, api_name="/image_gen_pro")
 
+# Load the processor and model for image-based QnA (LLaVA model)
 model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
-
 processor = LlavaProcessor.from_pretrained(model_id)
-
 model = LlavaForConditionalGeneration.from_pretrained(model_id)
 model.to("cpu")
 
-
+# Function to process images with text input
 def llava(message, history):
     if message["files"]:
         image = message["files"][0]
     else:
         for hist in history:
-            if type(hist[0])==tuple:
+            if type(hist[0]) == tuple:
                 image = hist[0][0]
 
     txt = message["text"]
@@ -43,12 +43,14 @@ def llava(message, history):
     inputs = processor(prompt, image, return_tensors="pt")
     return inputs
 
+# Helper function to extract text from a webpage
 def extract_text_from_webpage(html_content):
     soup = BeautifulSoup(html_content, 'html.parser')
     for tag in soup(["script", "style", "header", "footer"]):
         tag.extract()
     return soup.get_text(strip=True)
 
+# Function to search the web using Google
 def search(query):
     term = query
     start = 0
@@ -88,8 +90,8 @@ client_yi = InferenceClient("01-ai/Yi-1.5-34B-Chat")
 # Define the main chat function
 def respond(message, history):
     func_caller = []
-
     user_prompt = message
+
     # Handle image processing
     if message["files"]:
         inputs = llava(message, history)
@@ -101,9 +103,11 @@ def respond(message, history):
 
         buffer = ""
         for new_text in streamer:
-
-
+            if new_text not in ["<|im_end|>", "<|endoftext|>"]:  # Ignore special tokens
+                buffer += new_text
+                yield buffer
     else:
+        # Functions metadata for invoking different models or functions
         functions_metadata = [
             {"type": "function", "function": {"name": "web_search", "description": "Search query on google", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "web search query"}}, "required": ["query"]}}},
             {"type": "function", "function": {"name": "general_query", "description": "Reply general query of USER", "parameters": {"type": "object", "properties": {"prompt": {"type": "string", "description": "A detailed prompt"}}, "required": ["prompt"]}}},
@@ -120,45 +124,41 @@ def respond(message, history):
 
         response = client_gemma.chat_completion(func_caller, max_tokens=200)
         response = str(response)
+
+        # Filtering and processing response
         try:
             response = response[int(response.find("{")):int(response.rindex("</"))]
         except:
             response = response[int(response.find("{")):(int(response.rfind("}"))+1)]
-        response = response.replace("\\n", "")
-        response = response.replace("\\'", "'")
-        response = response.replace('\\"', '"')
-        response = response.replace('\\', '')
+        response = response.replace("\\n", "").replace("\\'", "'").replace('\\"', '"').replace('\\', '')
         print(f"\n{response}")
 
         try:
             json_data = json.loads(str(response))
             if json_data["name"] == "web_search":
                 query = json_data["arguments"]["query"]
-                # gr.Info("Searching Web")
                 web_results = search(query)
-                # gr.Info("Extracting relevant Info")
                 web2 = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results])
                 messages = f"<|im_start|>system\n Hi π, I am Nora,mini a helpful assistant.Ask me! I will do my best!! <|im_end|>"
                 for msg in history:
                     messages += f"\n<|im_start|>user\n{str(msg[0])}<|im_end|>"
                     messages += f"\n<|im_start|>assistant\n{str(msg[1])}<|im_end|>"
-                messages+=f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>web_result\n{web2}<|im_end|>\n<|im_start|>assistant\n"
+                messages += f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>web_result\n{web2}<|im_end|>\n<|im_start|>assistant\n"
                 stream = client_mixtral.text_generation(messages, max_new_tokens=2000, do_sample=True, stream=True, details=True, return_full_text=False)
                 output = ""
                 for response in stream:
-                    if not response.token.text
+                    if not response.token.text in ["<|im_end|>", "<|endoftext|>"]:  # Exclude special tokens
                         output += response.token.text
                         yield output
             elif json_data["name"] == "image_generation":
                 query = json_data["arguments"]["query"]
-                gr.Info("Generating Image, Please wait 10 sec...")
                 yield "Generating Image, Please wait 10 sec..."
                 try:
                     image = image_gen(f"{str(query)}")
                     yield gr.Image(image[1])
                 except:
                     client_sd3 = InferenceClient("stabilityai/stable-diffusion-3-medium-diffusers")
-                    seed = random.randint(0,999999)
+                    seed = random.randint(0, 999999)
                     image = client_sd3.text_to_image(query, negative_prompt=f"{seed}")
                     yield gr.Image(image)
             elif json_data["name"] == "image_qna":
@@ -168,33 +168,35 @@ def respond(message, history):
 
                 thread = Thread(target=model.generate, kwargs=generation_kwargs)
                 thread.start()
-
+
                 buffer = ""
                 for new_text in streamer:
-
-
+                    if new_text not in ["<|im_end|>", "<|endoftext|>"]:  # Ignore special tokens
+                        buffer += new_text
+                        yield buffer
             else:
                 messages = f"<|im_start|>system\n π, I am Nora,mini a helpful assistant.Ask me! I will do my best!!<|im_end|>"
                 for msg in history:
                     messages += f"\n<|im_start|>user\n{str(msg[0])}<|im_end|>"
                     messages += f"\n<|im_start|>assistant\n{str(msg[1])}<|im_end|>"
-                messages+=f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>assistant\n"
+                messages += f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>assistant\n"
                 stream = client_yi.text_generation(messages, max_new_tokens=2000, do_sample=True, stream=True, details=True, return_full_text=False)
                 output = ""
                 for response in stream:
-                    if
+                    if response.token.text not in ["<|im_end|>", "<|endoftext|>"]:  # Ignore special tokens
                         output += response.token.text
                         yield output
         except:
-
+            # Handle the case where JSON parsing or function calling fails
+            messages = f"<|im_start|>system\nHi π, I am Nora,mini a helpful assistant.Ask me! I will do my best!!<|im_end|>"
             for msg in history:
-                messages += f"\n<|
-                messages += f"\n<|
-            messages+=f"\n<|
+                messages += f"\n<|im_start|>user\n{str(msg[0])}<|im_end|>"
+                messages += f"\n<|im_start|>assistant\n{str(msg[1])}<|im_end|>"
+            messages += f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>assistant\n"
            stream = client_llama.text_generation(messages, max_new_tokens=2000, do_sample=True, stream=True, details=True, return_full_text=False)
             output = ""
             for response in stream:
-                if
+                if response.token.text not in ["<|eot_id|>", "<|im_end|>"]:  # Ignore special tokens
                     output += response.token.text
                     yield output
 
@@ -205,6 +207,9 @@ demo = gr.ChatInterface(
     textbox=gr.MultimodalTextbox(),
     multimodal=True,
     concurrency_limit=200,
-    cache_examples=False,
+    cache_examples=False,
+    css="footer{display:none !important}"
 )
+
+# Launch the Gradio app
 demo.launch()
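The edit this commit repeats in every streaming loop is a guard that keeps chat-template terminator tokens out of the visible reply. A minimal standalone sketch of that pattern, assuming huggingface_hub's InferenceClient as app.py already uses (the model choice matches `client_yi` above; the prompt and generation settings are illustrative):

```python
# Minimal sketch of the special-token filter this commit adds to each
# streaming loop in app.py.
from huggingface_hub import InferenceClient

client = InferenceClient("01-ai/Yi-1.5-34B-Chat")

def stream_reply(prompt):
    stream = client.text_generation(
        prompt,
        max_new_tokens=2000,
        do_sample=True,
        stream=True,
        details=True,          # yields token objects rather than plain strings
        return_full_text=False,
    )
    output = ""
    for chunk in stream:
        # chunk.token.text is one decoded token; skip ChatML terminators
        # instead of letting them leak into the chat window.
        if chunk.token.text not in ["<|im_end|>", "<|endoftext|>"]:
            output += chunk.token.text
            yield output  # yield the running text so Gradio updates live
```

Note that the web_search branch writes the guard as `if not response.token.text in [...]`, which Python parses as `not (text in [...])`, so it is equivalent to the `not in` form used in the other branches and in this sketch.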
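The diff never shows where `streamer` and `generation_kwargs` are defined, so the sketch below is an assumed reconstruction of the LLaVA streaming path, based on the `Thread(target=model.generate, kwargs=generation_kwargs)` call above and transformers' TextIteratorStreamer; it relies on the module-level `model` and `processor` loaded at the top of app.py, and `max_new_tokens` is a guess:

```python
# Assumed reconstruction (not shown in this diff) of the generation setup
# behind the `for new_text in streamer:` loops.
from threading import Thread
from transformers import TextIteratorStreamer

def stream_llava_reply(inputs):
    # Stream only newly generated text, not the prompt itself.
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
    # model.generate blocks, so it runs on a thread while we consume tokens.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    buffer = ""
    for new_text in streamer:
        if new_text not in ["<|im_end|>", "<|endoftext|>"]:  # same filter as the diff
            buffer += new_text
            yield buffer
```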