spagestic's picture
added file name headers
dd8e099
Raw
History Blame Contribute Delete
9.21 kB
# ui.py
import mimetypes
import os
import re
import shutil
import gradio as gr
from smolagents.gradio_ui import stream_to_gradio
from agent_factory import create_agent
_AVATAR = (
None,
"https://huggingface.co/datasets/huggingface/documentation-images/"
"resolve/main/smolagents/mascot_smol.png",
)
class GradioUI:
    """Gradio-based chat interface for the open-Deep-Research agent.

    Args:
        file_upload_folder: Directory in which uploaded files are stored.
            When ``None`` the upload widgets are not rendered.
    """

    def __init__(self, file_upload_folder: str | None = None):
        self.file_upload_folder = file_upload_folder
        if file_upload_folder:
            # makedirs also creates missing parent directories and is a no-op
            # when the directory already exists (no check-then-create race).
            os.makedirs(file_upload_folder, exist_ok=True)

    # ── Agent interaction ─────────────────────────────────────────────────────
    def interact_with_agent(self, prompt, messages, session_state):
        """Stream the agent's answer to ``prompt`` into the chat history.

        Args:
            prompt: Task text handed to the agent.
            messages: Chat history list; it is appended to in place and
                yielded again after every new message.
            session_state: Dict holding per-session data. The agent is created
                on first use and cached under the ``"agent"`` key.

        Yields:
            The updated ``messages`` list.

        Raises:
            Exception: Any error raised while streaming is logged and re-raised.
        """
        if "agent" not in session_state:
            session_state["agent"] = create_agent()
        agent = session_state["agent"]
        try:
            has_memory = hasattr(agent, "memory")
            print(f"Agent has memory: {has_memory}")
            if has_memory:
                print(f"Memory type: {type(agent.memory)}")
            messages.append(gr.ChatMessage(role="user", content=prompt))
            yield messages
            # reset_agent_memory=False keeps the conversation across turns.
            for msg in stream_to_gradio(
                agent, task=prompt, reset_agent_memory=False
            ):
                messages.append(msg)
                yield messages
            # Emit the final history once more after the stream has ended.
            yield messages
        except Exception as e:
            print(f"Error in interaction: {e}")
            raise

    # ── File upload ───────────────────────────────────────────────────────────
    @staticmethod
    def _sanitize_filename(path: str) -> str:
        """Return the base name of ``path`` with unsafe characters replaced by ``_``."""
        return re.sub(r"[^\w\-.]", "_", os.path.basename(path))

    def upload_file(
        self,
        file,
        file_uploads_log,
        allowed_file_types=(
            "application/pdf",
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            "text/plain",
        ),
    ):
        """Validate, sanitize, and persist an uploaded file.

        Args:
            file: Uploaded file object exposing its path as ``.name``, or
                ``None`` when nothing was uploaded.
            file_uploads_log: List of paths stored so far.
            allowed_file_types: MIME types that may be uploaded.

        Returns:
            A ``(status_textbox, file_uploads_log)`` pair. The log is returned
            unchanged on failure and extended with the stored path on success.
        """
        if file is None:
            return gr.Textbox("No file uploaded", visible=True), file_uploads_log
        try:
            mime_type, _ = mimetypes.guess_type(file.name)
        except Exception as e:
            return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log
        if mime_type not in allowed_file_types:
            return gr.Textbox("File type disallowed", visible=True), file_uploads_log

        # The MIME type was derived from the file's own extension, so that
        # extension is kept as is. Rebuilding it from a reversed
        # ``mimetypes.types_map`` is unreliable: several extensions share one
        # type, so the reverse lookup yields an arbitrary one of them.
        file_path = os.path.join(
            self.file_upload_folder, self._sanitize_filename(file.name)
        )
        try:
            shutil.copy(file.name, file_path)
        except OSError as e:
            # Report copy failures in the UI, like the type-detection errors.
            return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log
        return (
            gr.Textbox(f"File uploaded: {file_path}", visible=True),
            file_uploads_log + [file_path],
        )

    # ── Message helpers ───────────────────────────────────────────────────────
    def log_user_message(self, text_input, file_uploads_log):
        """Build the agent prompt and lock the inputs while the agent runs.

        Args:
            text_input: Text typed by the user.
            file_uploads_log: Paths of uploaded files; when non-empty they are
                mentioned at the end of the prompt.

        Returns:
            A ``(prompt, textbox, button)`` triple: the prompt text, a cleared
            non-interactive textbox and a non-interactive button.
        """
        suffix = (
            f"\nYou have been provided with these files, which might be helpful "
            f"or not: {file_uploads_log}"
            if file_uploads_log
            else ""
        )
        return (
            text_input + suffix,
            gr.Textbox(
                value="",
                interactive=False,
                placeholder="Please wait while Steps are getting populated",
            ),
            gr.Button(interactive=False),
        )

    # ── Device detection ──────────────────────────────────────────────────────
    def detect_device(self, request: gr.Request) -> str:
        """Classify the client as ``"Mobile"`` or ``"Desktop"`` from its headers.

        The ``sec-ch-ua-mobile`` header is checked first, then the user-agent
        string, then ``sec-ch-ua-platform``.

        Args:
            request: Incoming request; may be falsy when unavailable.

        Returns:
            ``"Unknown device"`` when there is no request, otherwise
            ``"Mobile"`` or ``"Desktop"`` (the fallback).
        """
        if not request:
            return "Unknown device"
        headers = request.headers

        is_mobile_header = headers.get("sec-ch-ua-mobile")
        if is_mobile_header:
            return "Mobile" if "?1" in is_mobile_header else "Desktop"

        ua = headers.get("user-agent", "").lower()
        if any(k in ua for k in ("android", "iphone", "ipad", "mobile", "phone")):
            return "Mobile"

        # The platform header value is compared with its quotes included.
        platform = headers.get("sec-ch-ua-platform", "").lower()
        if platform in ('"android"', '"ios"'):
            return "Mobile"

        # Everything else, including unknown platforms, counts as desktop.
        return "Desktop"

    # ── Layout helpers ────────────────────────────────────────────────────────
    def _reset_inputs_fn(self):
        """Return an interactive textbox and button to unlock the inputs."""
        return (
            gr.Textbox(
                interactive=True,
                placeholder="Enter your prompt here and press the button",
            ),
            gr.Button(interactive=True),
        )

    def _wire_events(
        self, text_input, launch_btn, file_uploads_log,
        chatbot, session_state, stored_messages
    ):
        """Attach submit/click event chains to inputs.

        Both triggers run the same chain: store the prompt and lock the
        inputs, stream the agent's answer into the chatbot, then unlock the
        inputs again.
        """
        for trigger in (text_input.submit, launch_btn.click):
            trigger(
                self.log_user_message,
                [text_input, file_uploads_log],
                [stored_messages, text_input, launch_btn],
            ).then(
                self.interact_with_agent,
                [stored_messages, chatbot, session_state],
                [chatbot],
            ).then(self._reset_inputs_fn, None, [text_input, launch_btn])

    def _upload_widget(self, file_uploads_log):
        """Render upload widgets when a folder is configured."""
        if self.file_upload_folder is None:
            return
        upload_file = gr.File(label="Upload a file")
        upload_status = gr.Textbox(
            label="Upload Status", interactive=False, visible=False
        )
        upload_file.change(
            self.upload_file,
            [upload_file, file_uploads_log],
            [upload_status, file_uploads_log],
        )

    def _desktop_layout(self):
        """Build the desktop UI: inputs in a sidebar, chatbot beside it."""
        file_uploads_log = gr.State([])
        with gr.Sidebar():
            gr.Markdown(
                """# open Deep Research - free the AI agents!"""
            )
            with gr.Group():
                gr.Markdown("**Your request**", container=True)
                text_input = gr.Textbox(
                    lines=3,
                    label="Your request",
                    container=False,
                    placeholder="Enter your prompt here and press Shift+Enter or press the button",
                )
                launch_btn = gr.Button("Run", variant="primary")
            self._upload_widget(file_uploads_log)
        session_state = gr.State({})
        stored_messages = gr.State([])
        chatbot = gr.Chatbot(
            label="open-Deep-Research",
            type="messages",
            avatar_images=_AVATAR,
            resizeable=False,
            scale=1,
            elem_id="my-chatbot",
        )
        self._wire_events(
            text_input, launch_btn, file_uploads_log,
            chatbot, session_state, stored_messages
        )

    def _mobile_layout(self):
        """Build the mobile UI: chatbot on top, inputs stacked below it."""
        gr.Markdown("# open Deep Research - free the AI agents!\n")
        session_state = gr.State({})
        stored_messages = gr.State([])
        file_uploads_log = gr.State([])
        chatbot = gr.Chatbot(
            label="open-Deep-Research",
            type="messages",
            avatar_images=_AVATAR,
            resizeable=True,
            scale=1,
        )
        self._upload_widget(file_uploads_log)
        text_input = gr.Textbox(
            lines=1,
            label="Your request",
            placeholder="Enter your prompt here and press the button",
        )
        launch_btn = gr.Button("Run", variant="primary")
        self._wire_events(
            text_input, launch_btn, file_uploads_log,
            chatbot, session_state, stored_messages
        )

    # ── Launch ────────────────────────────────────────────────────────────────
    def launch(self, **kwargs):
        """Build the app and start the Gradio server.

        The layout is chosen per request: the desktop layout when
        ``detect_device`` returns ``"Desktop"``, the mobile layout otherwise.

        Args:
            **kwargs: Forwarded to ``demo.launch``. ``debug`` defaults to
                ``True`` but can be overridden by the caller.
        """
        with gr.Blocks(theme="ocean", fill_height=True) as demo:

            @gr.render()
            def layout(request: gr.Request):
                device = self.detect_device(request)
                print(f"device - {device}")
                with gr.Blocks(fill_height=True):
                    if device == "Desktop":
                        self._desktop_layout()
                    else:
                        self._mobile_layout()

        # Merging instead of passing ``debug=True`` next to ``**kwargs`` avoids
        # a "multiple values for keyword argument" TypeError when the caller
        # supplies ``debug`` explicitly.
        launch_options = {"debug": True, **kwargs}
        demo.launch(**launch_options)