3v324v23 committed on
Commit af954c4 · 1 Parent(s): d081499
Files changed (1)
  1. app.py +51 -69
app.py CHANGED
@@ -11,37 +11,29 @@ def respond(
     max_tokens: int,
     temperature: float,
     top_p: float,
-    hf_token: gr.OAuthToken,
+    hf_token: gr.OAuthToken,  # <-- LoginButton injects this
 ):
-    """
-    Streams chat responses from a Hugging Face Inference Endpoint.
-
-    Notes:
-    - Requires your endpoint to allow inference with your token (permission:
-      `inference.endpoints.infer.write`).
-    - If the endpoint doesn't support OpenAI-style /v1/chat (e.g., plain TGI),
-      we fallback to a single-prompt `.text_generation()` call using a simple
-      prompt format built from the chat history.
-    """
-    # 1) Client that talks directly to your endpoint
+    # 0) Make sure user actually clicked "Login"
+    if hf_token is None or not getattr(hf_token, "token", None):
+        yield "🔒 Please click **Login** (left sidebar) to authorize Hugging Face access."
+        return
+
+    # 1) Create client against your endpoint (not model=)
     client = InferenceClient(
         base_url=ENDPOINT_URL,
-        token=hf_token.token,  # uses the OAuth token from the LoginButton
+        token=hf_token.token,  # <-- PAT from Login flow
     )
 
-    # 2) Build OpenAI-style messages for chat backends
+    # 2) Build messages for chat APIs
     messages = []
     if system_message:
         messages.append({"role": "system", "content": system_message})
-
-    # Gradio gives `history` as a list of {"role": "...", "content": "..."} when type="messages"
-    # Append previous turns, then the new user message
     messages.extend(history or [])
     messages.append({"role": "user", "content": user_msg})
 
-    # 3) Try OpenAI-style chat first (works if your endpoint exposes /v1/chat/completions)
+    # 3) Try OpenAI-style /v1/chat if your endpoint supports it
     try:
-        response_text = ""
+        out = ""
         for chunk in client.chat_completion(
             messages=messages,
             max_tokens=max_tokens,
@@ -49,92 +41,82 @@ def respond(
             top_p=top_p,
             stream=True,
         ):
-            # chunk.choices[0].delta.content is the streamed token (if present)
             token = ""
             if getattr(chunk, "choices", None) and getattr(chunk.choices[0], "delta", None):
                 token = chunk.choices[0].delta.content or ""
-            response_text += token
-            yield response_text
-        return  # success via chat api
-    except Exception as e:
-        # If chat endpoint isn't available, fall back to text_generation
-        # (common when the endpoint is plain TGI without OpenAI route enabled)
-        fallback_reason = str(e)
-
-    # 4) Fallback: Plain text generation with a simple chat-to-prompt adapter
+            out += token
+            yield out
+        return
+    except Exception as chat_err:
+        chat_err_msg = str(chat_err)
+
+    # 4) Fallback to plain text-generation (works on vanilla TGI endpoints)
     try:
-        def to_plain_prompt(msgs: List[Dict[str, str]]) -> str:
+        def to_prompt(msgs: List[Dict[str, str]]) -> str:
             lines = []
             for m in msgs:
                 role = m.get("role", "user")
                 content = m.get("content", "")
-                if role == "system":
-                    lines.append(f"[SYSTEM] {content}")
-                elif role == "user":
-                    lines.append(f"[USER] {content}")
-                else:
-                    lines.append(f"[ASSISTANT] {content}")
+                tag = {"system": "SYSTEM", "user": "USER"}.get(role, "ASSISTANT")
+                lines.append(f"[{tag}] {content}")
             lines.append("[ASSISTANT]")  # cue the model to speak
             return "\n".join(lines)
 
-        prompt = to_plain_prompt(messages)
+        prompt = to_prompt(messages)
 
-        response_text = ""
-        # stream text_generation tokens if the backend supports it
+        out = ""
         for tok in client.text_generation(
             prompt,
             max_new_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
             stream=True,
-            # Many TGI backends respect these kwargs; safe to include
             return_full_text=False,
         ):
-            # `tok` can be a string or an object depending on server; normalize to str
             piece = getattr(tok, "token", tok)
             if isinstance(piece, dict) and "text" in piece:
                 piece = piece["text"]
-            piece = str(piece)
-            response_text += piece
-            yield response_text
-
-    except Exception as e2:
-        # Surface a readable error in the chat window
-        err = (
-            "Failed to query the endpoint.\n\n"
-            f"- Chat attempt error: {fallback_reason}\n"
-            f"- Text-generation fallback error: {e2}\n\n"
-            "Check that your endpoint is running, your token has "
-            "`inference.endpoints.infer.write`, and the runtime supports either "
-            "OpenAI chat (/v1/chat/completions) or TGI text-generation."
-        )
-        yield err
+            out += str(piece)
+            yield out
+
+    except Exception as gen_err:
+        # 5) Clear, helpful errors for auth/permissions/runtime
+        err_text = f"""❗ Failed to query the endpoint.
 
+• Chat API error: {chat_err_msg}
+• Text-generation fallback error: {gen_err}
 
-# --- Gradio UI ---
-chatbot = gr.ChatInterface(
+Quick checks:
+1) You clicked **Login** and authorized this app.
+2) Your HF token includes `inference.endpoints.infer.write`.
+3) The endpoint is running and supports either OpenAI chat or TGI generation.
+Endpoint: {ENDPOINT_URL}
+"""
+        yield err_text
+
+
+# --- UI ---
+chat = gr.ChatInterface(
     respond,
-    type="messages",  # history comes as [{"role": "...", "content": "..."}]
+    type="messages",
     additional_inputs=[
         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.0, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
+        gr.Slider(1, 4096, value=512, step=1, label="Max new tokens"),
+        gr.Slider(0.0, 4.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(0.0, 1.0, value=0.95, step=0.05, label="Top-p"),
     ],
 )
 
 with gr.Blocks() as demo:
     with gr.Sidebar():
         gr.Markdown("### Hugging Face Login")
-        # This provides `hf_token: gr.OAuthToken` to `respond`
-        gr.LoginButton()
-        gr.Markdown(
-            "Make sure your token has **`inference.endpoints.infer.write`** permission."
-        )
+        gr.LoginButton()  # <-- keep this
         gr.Markdown(
-            f"**Endpoint**:\n\n`{ENDPOINT_URL}`"
+            "- Make sure your token has **`inference.endpoints.infer.write`**.\n"
+            "- This app will use your HF token only to call the endpoint."
        )
-    chatbot.render()
+        gr.Markdown(f"**Endpoint**: `{ENDPOINT_URL}`")
+    chat.render()
 
 if __name__ == "__main__":
     demo.launch()
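
For context: both hunks start at app.py line 11, so the file's preamble is not part of this diff. Below is a minimal sketch of what lines 1-10 presumably contain, inferred from the names the hunks rely on (gr, InferenceClient, ENDPOINT_URL, List, Dict, and the leading parameters of respond). The import layout and the placeholder URL are assumptions, not part of the commit.

# Hypothetical reconstruction of the unshown top of app.py (lines 1-10).
import gradio as gr
from typing import Dict, List

from huggingface_hub import InferenceClient

# Assumed placeholder: the diff only shows that a module-level ENDPOINT_URL exists.
ENDPOINT_URL = "https://your-endpoint.endpoints.huggingface.cloud"


def respond(
    user_msg: str,
    history: List[Dict[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    hf_token: gr.OAuthToken,  # injected by gr.LoginButton in the new version
):
    ...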
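
The to_prompt helper added in step 4 is nested inside respond, but it is easy to sanity-check if lifted to module scope. Given the messages built in step 2, the flat prompt it produces looks like this (output traced directly from the code in the diff):

# Assumes to_prompt has been lifted out of respond for testing.
messages = [
    {"role": "system", "content": "You are a friendly Chatbot."},
    {"role": "user", "content": "Hello!"},
]

print(to_prompt(messages))
# [SYSTEM] You are a friendly Chatbot.
# [USER] Hello!
# [ASSISTANT]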
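
One detail worth calling out in both code paths: the generator yields the accumulated out, not each token. gr.ChatInterface re-renders the bot message with every yielded value, so a streaming handler must yield the full text so far. A minimal, self-contained sketch of that contract (the token list is fake, purely for demonstration):

import time

import gradio as gr


def fake_stream(user_msg, history):
    # ChatInterface replaces the bot message with each yielded value,
    # so yield the running total, not the delta.
    out = ""
    for token in ["Strea", "ming ", "works ", "like ", "this."]:
        time.sleep(0.2)  # simulate network latency
        out += token
        yield out


demo = gr.ChatInterface(fake_stream, type="messages")

if __name__ == "__main__":
    demo.launch()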
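
A note on the new hf_token: gr.OAuthToken parameter: Gradio only populates it when the app runs on a Space with OAuth enabled; otherwise it arrives as None, which is exactly the case the new step 0 guards against. Assuming this app is deployed as a Space, OAuth is switched on via the README.md metadata, roughly:

---
sdk: gradio
hf_oauth: true
---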