File size: 6,375 Bytes
e920e8d
29a299d
 
 
 
 
 
 
 
 
 
 
e920e8d
29a299d
 
 
 
d04f615
29a299d
 
d04f615
 
29a299d
e920e8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29a299d
d04f615
29a299d
d04f615
 
29a299d
 
 
 
 
 
 
 
 
 
 
 
 
d04f615
 
e920e8d
 
 
 
 
29a299d
d04f615
29a299d
 
 
 
 
d04f615
 
 
 
29a299d
e920e8d
29a299d
 
 
 
 
e920e8d
 
 
 
 
 
 
 
 
 
 
 
 
29a299d
e920e8d
 
 
29a299d
d04f615
e920e8d
 
 
 
 
 
 
 
 
29a299d
 
 
e920e8d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import gradio as gr
import mimetypes
import os
import re
import shutil
from typing import Optional

from smolagents.agent_types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
from smolagents.agents import ActionStep, MultiStepAgent
from smolagents.memory import MemoryStep
from smolagents.utils import _is_package_available

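# NOTE: the pull_messages_from_step and stream_to_gradio helpers are omitted from this snippet and must be defined here: GradioUI.interact_with_agent calls stream_to_gradio, and __all__ exports it.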

class GradioUI:
    """Gradio front end for a smolagents ``MultiStepAgent``.

    The agent is wrapped in a ``gr.ChatInterface``. When ``file_upload_folder``
    is given, a file-upload widget is added as well. The paths of files saved
    through that widget are appended to each task sent to the agent.
    """

    def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None = None):
        """Store the agent and make sure the upload folder exists.

        Args:
            agent: The agent that answers chat messages.
            file_upload_folder: Directory that uploaded files are copied into.
                When ``None``, no upload widget is shown.

        Raises:
            ModuleNotFoundError: If the ``gradio`` package is not installed.
            FileExistsError: If ``file_upload_folder`` exists but is not a directory.
        """
        if not _is_package_available("gradio"):
            raise ModuleNotFoundError("Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`")
        self.agent = agent
        self.file_upload_folder = file_upload_folder
        if self.file_upload_folder is not None:
            # makedirs(exist_ok=True) also creates missing parent directories and
            # avoids the race between a separate exists() check and mkdir().
            os.makedirs(self.file_upload_folder, exist_ok=True)

    def interact_with_agent(self, prompt, chat_history, file_uploads_log=None):
        """Stream the agent's reply to ``prompt`` for ``gr.ChatInterface``.

        ``gr.ChatInterface`` calls its ``fn`` as
        ``fn(message, history, *additional_inputs)``. It keeps the history
        itself and renders each yielded value as the (partial) assistant reply
        for the current turn. This generator therefore yields the accumulated
        reply text, not the chat history.

        Args:
            prompt: The user's message.
            chat_history: Prior turns supplied by ``gr.ChatInterface``. It is
                unused here because the agent is run with
                ``reset_agent_memory=False`` and keeps its own memory.
            file_uploads_log: Paths of files uploaded so far. When non-empty,
                they are appended to the task sent to the agent.

        Yields:
            str: The assistant reply accumulated so far.
        """
        task, _ = self.log_user_message(prompt, file_uploads_log or [])

        full_response_content = ""
        for msg_obj in stream_to_gradio(self.agent, task=task, reset_agent_memory=False):
            if not isinstance(msg_obj, gr.ChatMessage):
                continue
            content = msg_obj.content
            if isinstance(content, str):
                full_response_content += content + "\n"
            elif isinstance(content, dict) and "path" in content:
                # Images/audio are summarized as text. Rendering media inline
                # would need a multimodal message format.
                mime_type = content.get("mime_type", "file")
                full_response_content += f"[{mime_type} at {content['path']}]\n"
            else:
                continue
            yield full_response_content

    def _save_upload(self, upload, allowed_file_types):
        """Validate and copy a single upload into ``self.file_upload_folder``.

        Returns:
            tuple: ``(status_message, saved_path)``. ``saved_path`` is ``None``
            when the upload was rejected or could not be copied.
        """
        # Gradio hands over either a plain path string or an object that
        # exposes the temporary file path as ``.name``.
        source_path = getattr(upload, "name", upload)
        try:
            mime_type, _ = mimetypes.guess_type(source_path)
        except Exception as e:  # e.g. a value that is not a path
            return f"Error: {e}", None

        if mime_type not in allowed_file_types:
            return "File type disallowed", None

        base_name, ext = os.path.splitext(os.path.basename(source_path))
        if not ext:
            # guess_extension already includes the leading dot (e.g. ".pdf").
            ext = mimetypes.guess_extension(mime_type) or ".txt"
        # Keep only word characters, dashes and dots so the name stays inside the folder.
        sanitized_name = re.sub(r"[^\w\-.]", "_", base_name + ext)
        file_path = os.path.join(self.file_upload_folder, sanitized_name)

        try:
            shutil.copy(source_path, file_path)
        except OSError as e:
            return f"Error: {e}", None
        return f"File uploaded: {file_path}", file_path

    def upload_file(self, file, file_uploads_log, allowed_file_types=None):
        """Copy uploaded file(s) into ``self.file_upload_folder``.

        ``file`` may be a single upload or a list of uploads. The widget built
        in ``launch`` uses ``file_count="multiple"``, so it passes a list.

        Args:
            file: Upload(s) from the ``gr.File`` component, or ``None``.
            file_uploads_log: Paths saved so far in this session.
            allowed_file_types: Accepted MIME types. Defaults to PDF, DOCX and
                plain text.

        Returns:
            tuple: A ``gr.Textbox`` update holding one status line per upload,
            and the updated list of saved paths. Paths already present in the
            list are not added again.
        """
        if allowed_file_types is None:
            allowed_file_types = ["application/pdf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "text/plain"]

        uploads = file if isinstance(file, (list, tuple)) else [file]
        uploads = [u for u in uploads if u is not None]
        if not uploads:
            return gr.Textbox("No file uploaded", visible=True), file_uploads_log

        statuses = []
        saved_paths = []
        for upload in uploads:
            status, saved_path = self._save_upload(upload, allowed_file_types)
            statuses.append(status)
            if saved_path is not None:
                saved_paths.append(saved_path)

        # A multi-file widget re-sends every file it holds on each upload event,
        # so skip paths that were already logged.
        new_paths = [p for p in dict.fromkeys(saved_paths) if p not in file_uploads_log]
        return gr.Textbox("\n".join(statuses), visible=True), file_uploads_log + new_paths

    def log_user_message(self, text_input, file_uploads_log):
        """Build the agent task from the user's text and any uploaded file paths.

        Returns:
            tuple: ``(task_text, "")``. The empty string can be used to clear
            the input textbox.
        """
        context = text_input
        if file_uploads_log:
            context += f"\nAttached files: {file_uploads_log}"
        return context, ""

    def launch(self, **kwargs):
        """Build the Gradio app and start serving it.

        ``debug`` and ``share`` default to ``True`` but can be overridden
        through ``kwargs``. All other keyword arguments are passed straight to
        ``Blocks.launch``.
        """
        with gr.Blocks(fill_height=True) as demo:
            file_uploads_log = gr.State([])

            gr.ChatInterface(
                fn=self.interact_with_agent,
                chatbot=gr.Chatbot(
                    label="Agent",
                    avatar_images=(
                        None,
                        "https://huggingface.co/datasets/agents-course/course-images/resolve/main/en/communication/Alfred.png",
                    ),
                    resizeable=True,
                    scale=1,
                ),
                textbox=gr.Textbox(lines=1, label="Chat Message"),
                title="Agent Chat",
                # Passed to interact_with_agent as its third argument, so the
                # agent sees the files uploaded through the widget below.
                additional_inputs=[file_uploads_log],
            )

            if self.file_upload_folder is not None:
                with gr.Row():
                    upload_file = gr.File(label="Upload a file", file_count="multiple")
                    upload_status = gr.Textbox(label="Upload Status", interactive=False, visible=False)
                upload_file.upload(
                    self.upload_file,
                    [upload_file, file_uploads_log],
                    [upload_status, file_uploads_log],
                )

        # setdefault lets callers override these without a duplicate-keyword TypeError.
        kwargs.setdefault("debug", True)
        kwargs.setdefault("share", True)
        demo.launch(**kwargs)

# Public API of this module.
# NOTE(review): "stream_to_gradio" is not defined in the visible code (see the
# placeholder comment after the imports). It must be defined in this module,
# otherwise `from <module> import *` raises AttributeError.
__all__ = ["stream_to_gradio", "GradioUI"]