File size: 4,406 Bytes
937f6c7
 
b9744a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
937f6c7
 
 
 
 
 
 
 
 
b9744a5
937f6c7
 
 
 
b9744a5
 
 
 
 
 
 
 
 
937f6c7
b9744a5
 
 
 
 
 
 
937f6c7
b9744a5
 
 
 
937f6c7
b9744a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
937f6c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import gradio as gr
from huggingface_hub import InferenceClient
import os
from typing import Optional

from datasets import load_dataset
from langchain.schema import Document
from tools.guestinforetriever import GuestInfoRetrieverTool
from tools.search_tool import SearchTool
from tools.weather_tool import WeatherTool


# Load dataset and initialize tools once at module import.
try:
    ds = load_dataset("agents-course/unit3-invitees")
    docs = []
    # Well-known text fields to probe, in order of preference.
    _preferred_fields = ("text", "content", "body", "description", "name")
    for split_name in ds.keys():
        for record in ds[split_name]:
            # Take the first populated preferred field; otherwise fall back
            # to the stringified record so no entry is dropped.
            content = next(
                (record[field] for field in _preferred_fields
                 if field in record and record[field]),
                None,
            )
            if content is None:
                content = str(record)
            docs.append(
                Document(page_content=str(content),
                         metadata={"source": str(split_name)})
            )

    guest_tool = GuestInfoRetrieverTool(docs)
    search_tool = SearchTool(docs)
except Exception:
    # Dataset load (or tool construction) failed; fall back to empty tools
    # so the app still starts — `respond` checks these for None.
    docs = []
    guest_tool = None
    search_tool = None

weather_tool = WeatherTool()


def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    hf_token: Optional[gr.OAuthToken],
):
    """Stream a chat reply, routing slash-commands to local tools.

    Messages beginning with ``/guest ``, ``/search ``, or ``/weather `` are
    answered by the corresponding tool; anything else is forwarded to the
    Hugging Face Inference API and streamed back incrementally.

    For more information on `huggingface_hub` Inference API support, please
    check the docs:
    https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference

    Args:
        message: The user's latest message.
        history: Prior turns as ``{"role": ..., "content": ...}`` dicts.
        system_message: System prompt prepended to the conversation.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature passed to the model.
        top_p: Nucleus-sampling probability mass.
        hf_token: OAuth token injected by Gradio's login button, if any.

    Yields:
        Progressively longer partial responses (Gradio streaming contract),
        or a single tool/error message.
    """
    # simple command routing for tools
    text = message.strip()
    if text.lower().startswith("/guest "):
        query = text[len("/guest "):].strip()
        if guest_tool:
            yield guest_tool.forward(query)
        else:
            yield "Guest retriever not available (dataset failed to load)."
        return

    if text.lower().startswith("/search "):
        query = text[len("/search "):].strip()
        if search_tool:
            yield search_tool.forward(query)
        else:
            yield "Search tool not available (dataset failed to load)."
        return

    if text.lower().startswith("/weather "):
        location = text[len("/weather "):].strip()
        yield weather_tool.forward(location)
        return

    # Default: call the HF chat model.
    # Prefer the Gradio OAuth token, fall back to env var `HUGGINGFACEHUB_API_TOKEN`.
    hf_token_value = None
    if hf_token and getattr(hf_token, "token", None):
        hf_token_value = hf_token.token
    else:
        hf_token_value = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

    if not hf_token_value:
        yield (
            "Missing Hugging Face API token. Please run `huggingface-cli login` or set the"
            " environment variable `HUGGINGFACEHUB_API_TOKEN` with a valid token (starts with 'hf_')."
        )
        return

    try:
        client = InferenceClient(token=hf_token_value, model="openai/gpt-oss-20b")
    except Exception as e:
        yield f"Failed to initialize Hugging Face InferenceClient: {e}"
        return

    messages = [{"role": "system", "content": system_message}]
    messages.extend(history)
    messages.append({"role": "user", "content": message})

    response = ""
    # NOTE: the loop variable is named `chunk`, not `message` — the original
    # shadowed the user-input parameter, which invites subtle bugs.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        choices = chunk.choices
        token = ""
        # `delta.content` may be None on keep-alive/final chunks.
        if choices and choices[0].delta.content:
            token = choices[0].delta.content

        response += token
        yield response


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
chatbot = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

# Top-level layout: a sidebar with the Hugging Face login button (which feeds
# the OAuth token into `respond`'s `hf_token` parameter) plus the chat UI.
with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.LoginButton()
    chatbot.render()


if __name__ == "__main__":
    demo.launch()