Spaces:
Build error
Build error
| import gradio as gr | |
| from huggingface_hub import InferenceClient | |
| import os | |
| from typing import Optional | |
| from datasets import load_dataset | |
| from langchain.schema import Document | |
| from tools.guestinforetriever import GuestInfoRetrieverTool | |
| from tools.search_tool import SearchTool | |
| from tools.weather_tool import WeatherTool | |
import logging

_logger = logging.getLogger(__name__)

# Fields inspected for readable text in each dataset record. Every non-empty
# one is kept (not only the first match), so a record carrying both a
# description and a name can still be retrieved by the guest's name.
_TEXT_FIELDS = ("name", "text", "content", "body", "description")


def _record_to_text(item):
    """Return the text to index for one dataset record.

    Joins the non-empty values of ``_TEXT_FIELDS`` present in *item*, one per
    line. Falls back to ``str(item)`` when none of those fields is present.
    """
    values = [str(item[key]) for key in _TEXT_FIELDS if key in item and item[key]]
    return "\n".join(values) if values else str(item)


def _build_tool(tool_cls, documents):
    """Construct ``tool_cls(documents)``, or return None if that is not possible.

    Returns None when *documents* is empty (for example because the dataset
    failed to load) or when the constructor raises. Constructor failures are
    logged with their traceback instead of being silently discarded.
    """
    if not documents:
        return None
    try:
        return tool_cls(documents)
    except Exception:
        _logger.exception("Failed to initialise %s", tool_cls.__name__)
        return None


# Load the dataset and initialize tools once at module import. A failure here
# only disables the dataset-backed commands; the app itself still starts.
try:
    ds = load_dataset("agents-course/unit3-invitees")
    docs = [
        Document(page_content=_record_to_text(item), metadata={"source": str(split)})
        for split in ds.keys()
        for item in ds[split]
    ]
except Exception:
    _logger.exception("Failed to load dataset 'agents-course/unit3-invitees'")
    docs = []

# Built independently so that one tool failing to initialise does not take
# the other one down with it.
guest_tool = _build_tool(GuestInfoRetrieverTool, docs)
search_tool = _build_tool(SearchTool, docs)
weather_tool = WeatherTool()
def _run_tool_command(text):
    """Handle the ``/guest``, ``/search`` and ``/weather`` chat commands.

    Args:
        text: The user's message, already stripped of surrounding whitespace.

    Returns:
        The reply to show when *text* is one of those commands. This is either
        the tool output, a usage hint when the argument is missing, or an
        error message when the tool is unavailable or raises. Returns ``None``
        when *text* is not a tool command and should go to the chat model.
    """
    parts = text.split(maxsplit=1)
    if not parts:
        return None
    command = parts[0].lower()
    arg = parts[1].strip() if len(parts) > 1 else ""

    if command == "/guest":
        tool, arg_name = guest_tool, "query"
        unavailable = "Guest retriever not available (dataset failed to load)."
    elif command == "/search":
        tool, arg_name = search_tool, "query"
        unavailable = "Search tool not available (dataset failed to load)."
    elif command == "/weather":
        tool, arg_name = weather_tool, "location"
        unavailable = "Weather tool not available."
    else:
        return None

    if not arg:
        return f"Usage: {command} <{arg_name}>"
    if tool is None:
        return unavailable
    try:
        # str() so a non-string tool result can still be shown in the chat.
        return str(tool.forward(arg))
    except Exception as exc:  # keep a failing tool from crashing the chat turn
        return f"The {command} command failed: {exc}"


def _resolve_hf_token(oauth_token):
    """Return the Hugging Face token to use, or ``None`` if none is available.

    The sources are tried in this order:

    1. The Gradio OAuth token from the sidebar login.
    2. The ``HUGGINGFACEHUB_API_TOKEN`` environment variable.
    3. Whatever ``huggingface_hub.get_token()`` finds (``HF_TOKEN`` or a token
       saved by ``huggingface-cli login``).
    """
    # getattr on None yields None, which covers the "not logged in" case.
    token = getattr(oauth_token, "token", None)
    if token:
        return token
    token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    if token:
        return token
    try:
        from huggingface_hub import get_token
    except ImportError:  # older huggingface_hub without the top-level helper
        return os.environ.get("HF_TOKEN") or None
    return get_token()


def _history_to_messages(history):
    """Convert Gradio chat history into ``{"role", "content"}`` dicts for the API.

    Gradio may pass the history in either of two layouts, depending on its
    version and the ChatInterface ``type`` setting:

    * "messages": dicts with ``role``/``content`` keys, matching ``respond``'s
      annotation.
    * Legacy "tuples": ``[user_text, bot_text]`` pairs.

    Only ``role`` and ``content`` are forwarded; any other keys Gradio attaches
    are dropped. Entries with no content are skipped.
    """
    converted = []
    for entry in history or []:
        if isinstance(entry, dict):
            if entry.get("content") is not None:
                converted.append(
                    {"role": entry.get("role", "user"), "content": entry["content"]}
                )
        elif isinstance(entry, (list, tuple)) and len(entry) == 2:
            user_text, bot_text = entry
            if user_text is not None:
                converted.append({"role": "user", "content": user_text})
            if bot_text is not None:
                converted.append({"role": "assistant", "content": bot_text})
    return converted


def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    hf_token: Optional[gr.OAuthToken],
):
    """Stream a reply to *message* for the Gradio ChatInterface.

    Messages starting with ``/guest``, ``/search`` or ``/weather`` are answered
    by the matching local tool. Everything else goes to the
    ``openai/gpt-oss-20b`` chat model through the Hugging Face Inference API.

    For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference

    Args:
        message: The new user message.
        history: Previous turns, in either Gradio history layout.
        system_message: System prompt placed first in the model conversation.
        max_tokens: Forwarded to ``chat_completion`` as ``max_tokens``.
        temperature: Forwarded to ``chat_completion`` as ``temperature``.
        top_p: Forwarded to ``chat_completion`` as ``top_p``.
        hf_token: Gradio OAuth token, or ``None`` when the user is not signed in.

    Yields:
        The reply text. For model replies, each yield is the text accumulated
        so far, so the UI can render the stream incrementally. Errors are
        yielded as text rather than raised.
    """
    tool_reply = _run_tool_command(message.strip())
    if tool_reply is not None:
        yield tool_reply
        return

    token = _resolve_hf_token(hf_token)
    if not token:
        yield (
            "Missing Hugging Face API token. Sign in with the button in the sidebar,"
            " run `huggingface-cli login`, or set the environment variable `HF_TOKEN`"
            " (or `HUGGINGFACEHUB_API_TOKEN`) to a valid token (starts with 'hf_')."
        )
        return

    try:
        client = InferenceClient(token=token, model="openai/gpt-oss-20b")
    except Exception as e:
        yield f"Failed to initialize Hugging Face InferenceClient: {e}"
        return

    messages = [{"role": "system", "content": system_message}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    response = ""
    try:
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            choices = chunk.choices
            if choices and choices[0].delta.content:
                response += choices[0].delta.content
            yield response
    except Exception as exc:  # auth/network/model errors: report in the chat
        failure = f"Hugging Face Inference API request failed: {exc}"
        yield f"{response}\n\n{failure}" if response else failure
# Gradio UI. ChatInterface customisation options are documented at
# https://www.gradio.app/docs/chatinterface
#
# NOTE(review): ChatInterface is not given `type="messages"`, so the history
# layout handed to respond() follows the installed Gradio version's default.
# Confirm it matches the list-of-dicts shape respond() is annotated with.
_system_message_box = gr.Textbox(value="You are a friendly Chatbot.", label="System message")
_max_tokens_slider = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
_temperature_slider = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
_top_p_slider = gr.Slider(
    label="Top-p (nucleus sampling)",
    minimum=0.1,
    maximum=1.0,
    step=0.05,
    value=0.95,
)

# Order matters: these line up with respond()'s parameters that follow
# (message, history): system_message, max_tokens, temperature, top_p.
chatbot = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        _system_message_box,
        _max_tokens_slider,
        _temperature_slider,
        _top_p_slider,
    ],
)

with gr.Blocks() as demo:
    # Signing in through this button is what lets Gradio pass a
    # gr.OAuthToken into respond()'s hf_token parameter.
    with gr.Sidebar():
        gr.LoginButton()
    chatbot.render()


def _main():
    """Start the Gradio server for the page built above."""
    demo.launch()


if __name__ == "__main__":
    _main()