File size: 4,311 Bytes
98eadee
 
7dad0a9
98eadee
 
 
 
 
 
b60854a
 
98eadee
c5e7c2f
7dad0a9
 
 
c158e59
db2bcbc
 
7dad0a9
 
2731988
b60854a
7dad0a9
2731988
 
c158e59
2731988
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c158e59
2731988
 
c158e59
2731988
7dad0a9
 
98eadee
 
2731988
98eadee
 
 
 
 
9b10c1a
14865ad
9b10c1a
98eadee
9b10c1a
2731988
9b10c1a
7dad0a9
 
98eadee
 
 
 
7dad0a9
98eadee
 
c5e7c2f
 
b60854a
98eadee
9b10c1a
 
 
b60854a
9b10c1a
98eadee
7dad0a9
98eadee
 
7dad0a9
98eadee
9b10c1a
c5e7c2f
98eadee
 
c158e59
 
 
98eadee
 
c5e7c2f
2731988
 
 
7dad0a9
98eadee
 
 
c5e7c2f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import gradio as gr
from typing import Iterator, List, Dict, Any, Tuple

from backend_hf_api import HFInferenceBackend, is_hf_api_available

# Runtime configuration — every value is overridable via environment variables
# (on HF Spaces: Settings → Variables), with sensible fall-backs.
SYSTEM_PROMPT_DEFAULT = os.getenv("SYSTEM_PROMPT", "You are a helpful assistant. Be concise and accurate.")
# Generation defaults; users can still adjust both from the "Advanced" accordion.
DEFAULT_MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "512"))
DEFAULT_TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
# Use a valid Nemotron repo by default; override via Space Variables if you want another.
# NOTE(review): confirm this repo id exists and is served by the HF Inference API —
# an invalid id surfaces only at request time via chat_fn's error path.
DEFAULT_HF_API_MODEL = os.getenv("HF_API_MODEL", "NVIDIA/Nemotron-3-8B-Instruct")


def _msg_content_to_text(content: Any) -> str:
    if isinstance(content, str):
        return content
    if isinstance(content, dict) and isinstance(content.get("text"), str):
        return content["text"]
    return "" if content is None else str(content)


def _history_to_pairs(history: Any) -> List[Tuple[str, str]]:
    """Gradio v6 messages or legacy (user, assistant) pairs → (user, assistant) pairs."""
    pairs: List[Tuple[str, str]] = []
    if not history:
        return pairs

    if isinstance(history[0], dict):
        pending_user: str | None = None
        for m in history:
            role = m.get("role")
            text = _msg_content_to_text(m.get("content"))
            if role == "user":
                if pending_user is not None:
                    pairs.append((pending_user, ""))
                pending_user = text
            elif role == "assistant":
                if pending_user is None:
                    pairs.append(("", text))
                else:
                    pairs.append((pending_user, text))
                    pending_user = None
        if pending_user is not None:
            pairs.append((pending_user, ""))
        return pairs

    if isinstance(history[0], (list, tuple)) and len(history[0]) == 2:
        return [(str(u or ""), str(a or "")) for (u, a) in history]

    return [(str(history), "")]


def chat_fn(
    message: str,
    history: List[Dict[str, Any]] | List[Tuple[str, str]],
    model_name: str,
    system_prompt: str,
    temperature: float,
    max_new_tokens: int,
) -> Iterator[str]:
    """Streaming handler for ``gr.ChatInterface``.

    Yields generated text chunks from the HF Inference API backend. All
    failures (missing token, bad model id, API errors) are surfaced as a
    single ``[error] ...`` message instead of raising, so the UI never crashes.
    """
    if not is_hf_api_available():
        yield "[error] HF_TOKEN not set. Add it in Spaces → Settings → Secrets and restart."
        return
    try:
        # Fall back to the configured defaults when the UI fields are blank.
        repo = model_name or DEFAULT_HF_API_MODEL
        prompt = (system_prompt or SYSTEM_PROMPT_DEFAULT).strip()
        backend = HFInferenceBackend(repo)
        stream = backend.generate_stream(
            system_prompt=prompt,
            history=_history_to_pairs(history),
            user_msg=message,
            temperature=float(temperature),
            max_new_tokens=int(max_new_tokens),
        )
        for chunk in stream:
            yield chunk
    except Exception as e:
        yield f"[error] {type(e).__name__}: {e}"


# UI layout. NOTE: component creation order inside gr.Blocks() determines
# render order, so do not reorder these statements.
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 HF Inference API Chatbot (Gradio v6)\nUses your **HF_TOKEN**. Preflight checks model to prevent crashes.")

    # Model repo is editable at runtime; chat_fn falls back to
    # DEFAULT_HF_API_MODEL when this box is left empty.
    model_name = gr.Textbox(
        value=DEFAULT_HF_API_MODEL,
        label="HF model repo",
        placeholder="e.g., NVIDIA/Nemotron-3-8B-Instruct",
    )

    # Generation knobs, collapsed by default; wired into ChatInterface below
    # via additional_inputs_accordion.
    with gr.Accordion("Advanced", open=False) as adv:
        system_prompt = gr.Textbox(value=SYSTEM_PROMPT_DEFAULT, label="System prompt", lines=3)
        temperature = gr.Slider(0.0, 1.5, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature")
        max_new_tokens = gr.Slider(16, 4096, value=DEFAULT_MAX_NEW_TOKENS, step=16, label="Max new tokens")

    gr.ChatInterface(
        fn=chat_fn,
        title="Chat",
        # Each example row supplies the message plus one value per
        # additional input, in the same order as additional_inputs.
        examples=[
            ["Summarize why the sky is blue in 3 sentences.", DEFAULT_HF_API_MODEL, SYSTEM_PROMPT_DEFAULT, DEFAULT_TEMPERATURE, DEFAULT_MAX_NEW_TOKENS],
            ["Draft a friendly product blurb for a coffee mug.", DEFAULT_HF_API_MODEL, SYSTEM_PROMPT_DEFAULT, DEFAULT_TEMPERATURE, DEFAULT_MAX_NEW_TOKENS],
            ["Explain binary search with a tiny Python example.", DEFAULT_HF_API_MODEL, SYSTEM_PROMPT_DEFAULT, DEFAULT_TEMPERATURE, DEFAULT_MAX_NEW_TOKENS],
        ],
        # Examples would hit the paid inference API at startup if cached.
        cache_examples=False,
        # Order here must match chat_fn's parameters after (message, history).
        additional_inputs=[model_name, system_prompt, temperature, max_new_tokens],
        additional_inputs_accordion=adv,
        save_history=True,
        editable=True,
        autoscroll=True,
    )

if __name__ == "__main__":
    # queue() is required for streaming (generator) handlers like chat_fn.
    demo.queue().launch()