akhaliq HF Staff commited on
Commit
cd17de3
·
1 Parent(s): fc50e95

Pass workflow OAuth token to InferenceClient via ContextVar

Browse files
Files changed (1) hide show
  1. app.py +22 -5
app.py CHANGED
@@ -4,12 +4,20 @@ import gradio as gr
4
  from huggingface_hub import InferenceClient
5
  from huggingface_hub import get_token as hf_get_token
6
  from gradio.context import LocalContext
 
 
 
 
7
 
8
  def get_hf_token() -> str | None:
9
  """
10
- Retrieves the HF API token from either the user's Gradio OAuth session
11
- or falls back to the system environment/CLI cached token.
12
  """
 
 
 
 
13
  request = LocalContext.request.get(None)
14
  if request is not None:
15
  session = getattr(request, "session", {})
@@ -31,8 +39,12 @@ def generate_prompt(concept: str) -> str:
31
  if not concept:
32
  return "a ginger cat wearing a tiny wizard hat reading a spellbook"
33
  try:
34
- token = get_hf_token()
35
- client = InferenceClient("nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-NVFP4", token=token)
 
 
 
 
36
  system_instruction = (
37
  "You are an expert prompt engineer for text-to-image models. "
38
  "Your task is to take a simple concept and expand it into a detailed, "
@@ -46,6 +58,7 @@ def generate_prompt(concept: str) -> str:
46
  {"role": "user", "content": f"Concept: {concept}"}
47
  ]
48
  response = client.chat_completion(
 
49
  messages=messages,
50
  temperature=0.7,
51
  max_tokens=256
@@ -70,7 +83,7 @@ def generate_image(prompt: str) -> dict:
70
  if not prompt:
71
  prompt = "a ginger cat wearing a tiny wizard hat reading a spellbook"
72
  try:
73
- token = get_hf_token()
74
  client = InferenceClient(
75
  provider="auto",
76
  api_key=token,
@@ -125,6 +138,10 @@ class LockedWorkflow(gr.Workflow):
125
  bound = self._bound
126
 
127
  def call_fn(data, _token=None) -> str:
 
 
 
 
128
  fn_name = data[0] if data else ""
129
  try:
130
  args_json = data[1] if len(data) > 1 else "[]"
 
4
  from huggingface_hub import InferenceClient
5
  from huggingface_hub import get_token as hf_get_token
6
  from gradio.context import LocalContext
7
+ import contextvars
8
+
9
+ workflow_token = contextvars.ContextVar("workflow_token", default=None)
10
+
11
 
12
  def get_hf_token() -> str | None:
13
  """
14
+ Retrieves the HF API token from either the workflow context,
15
+ the user's Gradio OAuth session, or falls back to the system environment.
16
  """
17
+ w_token = workflow_token.get()
18
+ if w_token:
19
+ return w_token
20
+
21
  request = LocalContext.request.get(None)
22
  if request is not None:
23
  session = getattr(request, "session", {})
 
39
  if not concept:
40
  return "a ginger cat wearing a tiny wizard hat reading a spellbook"
41
  try:
42
+ token = get_hf_token() or os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN")
43
+ client = InferenceClient(
44
+ provider="together",
45
+ api_key=token,
46
+ bill_to="huggingface",
47
+ )
48
  system_instruction = (
49
  "You are an expert prompt engineer for text-to-image models. "
50
  "Your task is to take a simple concept and expand it into a detailed, "
 
58
  {"role": "user", "content": f"Concept: {concept}"}
59
  ]
60
  response = client.chat_completion(
61
+ model="nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-NVFP4",
62
  messages=messages,
63
  temperature=0.7,
64
  max_tokens=256
 
83
  if not prompt:
84
  prompt = "a ginger cat wearing a tiny wizard hat reading a spellbook"
85
  try:
86
+ token = get_hf_token() or os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN")
87
  client = InferenceClient(
88
  provider="auto",
89
  api_key=token,
 
138
  bound = self._bound
139
 
140
  def call_fn(data, _token=None) -> str:
141
+ if _token:
142
+ t_str = _token.token if hasattr(_token, "token") else _token
143
+ if t_str:
144
+ workflow_token.set(str(t_str))
145
  fn_name = data[0] if data else ""
146
  try:
147
  args_json = data[1] if len(data) > 1 else "[]"