youssefreda9 committed on
Commit
381091e
·
1 Parent(s): f8422c2

debug: Deep diagnostics to find InferenceClient internal transport

Browse files
Files changed (1) hide show
  1. src/hf_inference.py +126 -109
src/hf_inference.py CHANGED
@@ -2,13 +2,7 @@
2
  HuggingFace Inference API client for Bayan models.
3
 
4
  Uses huggingface_hub.InferenceClient which routes through HF's internal
5
- network when running inside HF Spaces (bypasses external DNS).
6
-
7
- Models:
8
- - bayan10/summarization-model (MBart, summarization pipeline)
9
- - bayan10/AraSpell-Model (spelling correction)
10
- - bayan10/PuncAra-v1 (punctuation, encoder-decoder)
11
- - bayan10/AutoComplete (text generation / fill-mask)
12
  """
13
 
14
  import os
@@ -20,168 +14,191 @@ logger = logging.getLogger(__name__)
20
 
21
  HF_API_TOKEN = os.environ.get("HF_API_TOKEN", "")
22
 
23
- # Lazy-initialized client
24
  _client = None
25
 
26
 
27
  def _get_client():
28
- """Get or create the InferenceClient singleton."""
29
  global _client
30
  if _client is None:
31
  from huggingface_hub import InferenceClient
32
  _client = InferenceClient(token=HF_API_TOKEN if HF_API_TOKEN else None)
33
- logger.info("InferenceClient initialized (token=%s)", "set" if HF_API_TOKEN else "not set")
34
  return _client
35
 
36
 
37
- # ============================================================
38
  # Repository IDs
39
- # ============================================================
40
-
41
  SUMMARIZATION_REPO = os.environ.get("SUMMARIZATION_REPO_ID", "bayan10/summarization-model")
42
  SPELLING_REPO = os.environ.get("SPELLING_REPO_ID", "bayan10/AraSpell-Model")
43
  PUNCTUATION_REPO = os.environ.get("PUNCTUATION_REPO_ID", "bayan10/PuncAra-v1")
44
  AUTOCOMPLETE_REPO = os.environ.get("AUTOCOMPLETE_REPO_ID", "bayan10/AutoComplete")
45
 
46
 
47
- # ============================================================
48
- # Model-specific wrappers using InferenceClient typed methods
49
- # ============================================================
50
-
51
- def hf_summarize(text, max_length=128, min_length=30):
52
- """Summarize Arabic text via HF Inference API."""
53
  client = _get_client()
54
- logger.info("Calling summarization: %s", SUMMARIZATION_REPO)
55
 
56
- result = client.summarization(text, model=SUMMARIZATION_REPO)
 
 
 
 
57
 
58
- logger.info("Summarization result: %s %s", type(result).__name__, str(result)[:150])
 
 
 
 
 
59
 
60
- # SummarizationOutput has .summary_text
61
- if hasattr(result, "summary_text"):
62
- return result.summary_text
63
- if isinstance(result, dict):
64
- return result.get("summary_text", result.get("generated_text", str(result)))
65
- return str(result)
 
 
66
 
 
 
67
 
68
- def hf_correct_spelling(text):
69
- """Correct spelling in Arabic text via HF Inference API."""
70
- client = _get_client()
71
- logger.info("Calling spelling: %s", SPELLING_REPO)
72
 
73
- # Try text2text_generation first (for seq2seq models), fall back to text_generation
74
- try:
75
- result = client.text2text_generation(text, model=SPELLING_REPO)
76
- logger.info("Spelling result (t2t): %s — %s", type(result).__name__, str(result)[:150])
77
- if hasattr(result, "generated_text"):
78
- return result.generated_text
79
- if isinstance(result, str):
80
- return result if result.strip() else text
81
- if isinstance(result, dict):
82
- return result.get("generated_text", text)
83
- return text
84
- except Exception as e1:
85
- logger.warning("text2text_generation failed for spelling: %s", repr(e1)[:200])
86
- try:
87
- result = client.text_generation(text, model=SPELLING_REPO, max_new_tokens=len(text) + 50)
88
- logger.info("Spelling result (tg): %s — %s", type(result).__name__, str(result)[:150])
89
- if isinstance(result, str):
90
- return result if result.strip() else text
91
- return text
92
- except Exception as e2:
93
- logger.error("text_generation also failed for spelling: %s", repr(e2)[:200])
94
- raise
95
 
 
 
96
 
97
- def hf_add_punctuation(text):
98
- """Add punctuation to Arabic text via HF Inference API."""
99
- client = _get_client()
100
- logger.info("Calling punctuation: %s", PUNCTUATION_REPO)
101
 
102
- try:
103
- result = client.text2text_generation(text, model=PUNCTUATION_REPO)
104
- logger.info("Punctuation result (t2t): %s — %s", type(result).__name__, str(result)[:150])
105
- if hasattr(result, "generated_text"):
106
- return result.generated_text
107
- if isinstance(result, str):
108
- return result if result.strip() else text
109
- if isinstance(result, dict):
110
- return result.get("generated_text", text)
111
- return text
112
- except Exception as e1:
113
- logger.warning("text2text_generation failed for punctuation: %s", repr(e1)[:200])
114
- try:
115
- result = client.text_generation(text, model=PUNCTUATION_REPO, max_new_tokens=len(text) + 50)
116
- logger.info("Punctuation result (tg): %s — %s", type(result).__name__, str(result)[:150])
117
- if isinstance(result, str):
118
- return result if result.strip() else text
119
- return text
120
- except Exception as e2:
121
- logger.error("text_generation also failed for punctuation: %s", repr(e2)[:200])
122
- raise
123
 
 
 
 
124
 
125
- def hf_autocomplete(text, n=5):
126
- """Get autocomplete suggestions for Arabic text via HF Inference API."""
127
- client = _get_client()
128
- logger.info("Calling autocomplete: %s", AUTOCOMPLETE_REPO)
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- result = client.text_generation(text, model=AUTOCOMPLETE_REPO, max_new_tokens=20)
131
 
132
- logger.info("Autocomplete result: %s — %s", type(result).__name__, str(result)[:150])
 
 
 
 
 
 
 
133
 
134
- if isinstance(result, str):
135
- completion = result[len(text):].strip() if result.startswith(text) else result
136
- return [completion] if completion else [text]
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  return [text]
139
 
140
 
141
  def check_hf_api_available():
142
- """Quick check if HF Inference API is reachable."""
143
  try:
144
- client = _get_client()
145
- return client is not None
146
  except Exception:
147
  return False
148
 
149
 
150
  def debug_test_all_models():
151
- """
152
- Test all HF models and return results dict.
153
- Also includes diagnostic info about InferenceClient internals.
154
- """
155
  results = {}
156
  test_text = "هذا نص تجريبي للاختبار"
157
  long_text = (test_text + " ") * 5
158
 
159
- # Diagnostic info
160
  try:
161
  client = _get_client()
 
 
 
 
 
 
 
162
  diag = {
163
- "client_type": type(client).__name__,
164
- "api_url": getattr(client, "api_url", "N/A"),
165
- "base_url": getattr(client, "base_url", "N/A"),
166
- "model": getattr(client, "model", "N/A"),
 
167
  }
168
- # Check available methods
169
- diag["has_post"] = hasattr(client, "post")
170
- diag["has_text2text"] = hasattr(client, "text2text_generation")
171
- diag["has_summarization"] = hasattr(client, "summarization")
172
- diag["has_text_generation"] = hasattr(client, "text_generation")
 
 
 
173
  except Exception as e:
174
- diag = {"error": repr(e)[:200]}
175
 
176
  results["_diagnostics"] = diag
177
-
178
- # Test env vars
179
  results["_env"] = {
180
  "HF_INFERENCE_ENDPOINT": os.environ.get("HF_INFERENCE_ENDPOINT", "NOT SET"),
181
- "HF_API_URL": os.environ.get("HF_API_URL", "NOT SET"),
182
  "SPACE_ID": os.environ.get("SPACE_ID", "NOT SET"),
183
  }
184
 
 
185
  for name, fn, args in [
186
  ("summarization", hf_summarize, (long_text, 30, 10)),
187
  ("spelling", hf_correct_spelling, (test_text,)),
 
2
  HuggingFace Inference API client for Bayan models.
3
 
4
  Uses huggingface_hub.InferenceClient which routes through HF's internal
5
+ network when running inside HF Spaces.
 
 
 
 
 
 
6
  """
7
 
8
  import os
 
14
 
15
  HF_API_TOKEN = os.environ.get("HF_API_TOKEN", "")
16
 
 
17
  _client = None
18
 
19
 
20
  def _get_client():
 
21
  global _client
22
  if _client is None:
23
  from huggingface_hub import InferenceClient
24
  _client = InferenceClient(token=HF_API_TOKEN if HF_API_TOKEN else None)
25
+ logger.info("InferenceClient initialized")
26
  return _client
27
 
28
 
 
29
  # Repository IDs
 
 
30
  SUMMARIZATION_REPO = os.environ.get("SUMMARIZATION_REPO_ID", "bayan10/summarization-model")
31
  SPELLING_REPO = os.environ.get("SPELLING_REPO_ID", "bayan10/AraSpell-Model")
32
  PUNCTUATION_REPO = os.environ.get("PUNCTUATION_REPO_ID", "bayan10/PuncAra-v1")
33
  AUTOCOMPLETE_REPO = os.environ.get("AUTOCOMPLETE_REPO_ID", "bayan10/AutoComplete")
34
 
35
 
36
def _raw_inference(repo_id, payload):
    """
    Send a raw JSON inference request for ``repo_id`` through the InferenceClient.

    Two transports are tried, in order:

    1. ``client._post`` when the client exposes a callable of that name.
    2. The first session-like attribute of the client that has a callable
       ``post``; the request is sent to ``<api url>/<repo_id>``.

    Args:
        repo_id: Model repository id, e.g. ``"bayan10/AraSpell-Model"``.
        payload: JSON-serializable request body.

    Returns:
        The decoded JSON response, in whatever shape the endpoint returned.

    Raises:
        RuntimeError: If no usable transport is found on the client.
    """
    # Local import so this function does not depend on the module's import block.
    import json

    client = _get_client()

    # Approach 1: the client's internal _post method.
    internal_post = getattr(client, "_post", None)
    if callable(internal_post):
        logger.info("Using client._post for %s", repo_id)
        response = internal_post(json=payload, model=repo_id)
        if isinstance(response, (bytes, str)):
            return json.loads(response)
        return response

    # Approach 2: borrow the client's own HTTP session. Keep looking past
    # attributes that exist but are None or have no ``post`` method.
    session = None
    for attr in ("_session", "session", "_client", "client", "_http_client"):
        candidate = getattr(client, attr, None)
        if callable(getattr(candidate, "post", None)):
            session = candidate
            break

    if session is None:
        raise RuntimeError("No usable transport found on InferenceClient")

    # Base API URL: first non-empty string attribute, else the public default.
    api_url = "https://api-inference.huggingface.co/models"
    for attr in ("api_url", "base_url", "_api_url", "inference_url", "_base_url"):
        value = getattr(client, attr, None)
        if value and isinstance(value, str):
            api_url = value
            break

    # Strip the trailing slash in both branches so the URL never contains "//".
    base = api_url.rstrip("/")
    if "/models" in base:
        url = base + "/" + repo_id
    else:
        url = base + "/models/" + repo_id
    logger.info("Using session.post to %s", url)

    headers = {"Content-Type": "application/json"}
    if HF_API_TOKEN:
        headers["Authorization"] = "Bearer " + HF_API_TOKEN

    resp = session.post(url, json=payload, headers=headers, timeout=120)

    # Surface HTTP failures in the log; the body is still returned to the
    # caller, as before, so existing fallback handling keeps working.
    status = getattr(resp, "status_code", None)
    if isinstance(status, int) and status >= 400:
        logger.warning("Inference call to %s returned HTTP %s", url, status)
    return resp.json()
 
 
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
+ # ============================================================
82
+ # Model wrappers
83
+ # ============================================================
84
 
85
+ def _extract_text(result, fallback=""):
86
+ """Extract text from various HF response formats."""
87
+ if isinstance(result, list) and len(result) > 0:
88
+ item = result[0]
89
+ if isinstance(item, dict):
90
+ return (item.get("summary_text")
91
+ or item.get("generated_text")
92
+ or item.get("translation_text")
93
+ or fallback)
94
+ return str(item) if str(item).strip() else fallback
95
+ if isinstance(result, dict):
96
+ return (result.get("summary_text")
97
+ or result.get("generated_text")
98
+ or result.get("translation_text")
99
+ or fallback)
100
+ return str(result) if result else fallback
101
 
 
102
 
103
def hf_summarize(text, max_length=128, min_length=30):
    """
    Summarize Arabic text with the summarization model.

    ``max_length`` and ``min_length`` are forwarded as request parameters.
    When no text can be extracted from the response, the first 100
    characters of ``text`` are returned instead.
    """
    generation_params = {"max_length": max_length, "min_length": min_length}
    request_body = {
        "inputs": text,
        "parameters": generation_params,
        "options": {"wait_for_model": True},
    }
    response = _raw_inference(SUMMARIZATION_REPO, request_body)
    excerpt = text[:100]
    return _extract_text(response, excerpt)
111
 
 
 
 
112
 
113
def hf_correct_spelling(text):
    """
    Correct spelling in Arabic text with the spelling model.

    Returns ``text`` unchanged when no text can be extracted from the response.
    """
    request_body = {
        "inputs": text,
        "options": {"wait_for_model": True},
    }
    response = _raw_inference(SPELLING_REPO, request_body)
    return _extract_text(response, text)
120
+
121
+
122
def hf_add_punctuation(text):
    """
    Add punctuation to Arabic text with the punctuation model.

    Returns ``text`` unchanged when no text can be extracted from the response.
    """
    request_body = {
        "inputs": text,
        "options": {"wait_for_model": True},
    }
    response = _raw_inference(PUNCTUATION_REPO, request_body)
    return _extract_text(response, text)
129
+
130
+
131
def hf_autocomplete(text, n=5):
    """
    Get autocomplete suggestions for ``text``.

    Args:
        text: Prompt to complete.
        n: Maximum number of suggestions to return; ``None`` means no limit.

    Returns:
        A list of suggestions. When a generated string starts with ``text``,
        that prefix is removed and the remainder stripped. If nothing usable
        comes back, the list is ``[text]``.
    """
    result = _raw_inference(AUTOCOMPLETE_REPO, {
        "inputs": text,
        "parameters": {"max_new_tokens": 20},
        "options": {"wait_for_model": True},
    })

    if isinstance(result, str):
        candidates = [result]
    elif isinstance(result, list):
        candidates = [
            item.get("generated_text") if isinstance(item, dict) else item
            for item in result
        ]
    else:
        return [text]

    suggestions = []
    for candidate in candidates:
        # Skip missing values instead of failing on .startswith.
        if not candidate:
            continue
        generated = candidate if isinstance(candidate, str) else str(candidate)
        if generated.startswith(text):
            generated = generated[len(text):].strip()
        if generated:
            suggestions.append(generated)

    # Honor the caller's requested maximum.
    if n is not None:
        suggestions = suggestions[:max(n, 0)]
    return suggestions or [text]
152
 
153
 
154
  def check_hf_api_available():
 
155
  try:
156
+ return _get_client() is not None
 
157
  except Exception:
158
  return False
159
 
160
 
161
  def debug_test_all_models():
162
+ """Full diagnostic debug."""
 
 
 
163
  results = {}
164
  test_text = "هذا نص تجريبي للاختبار"
165
  long_text = (test_text + " ") * 5
166
 
167
+ # Deep diagnostics
168
  try:
169
  client = _get_client()
170
+ import huggingface_hub
171
+ all_attrs = [a for a in dir(client) if not a.startswith('__')]
172
+ private_attrs = {a: str(type(getattr(client, a, None)).__name__)
173
+ for a in all_attrs if a.startswith('_') and not a.startswith('__')}
174
+ public_methods = [a for a in all_attrs if not a.startswith('_')
175
+ and callable(getattr(client, a, None))]
176
+
177
  diag = {
178
+ "hf_hub_version": getattr(huggingface_hub, '__version__', 'unknown'),
179
+ "has__post": hasattr(client, '_post'),
180
+ "has_session": hasattr(client, '_session') or hasattr(client, 'session'),
181
+ "private_attrs": private_attrs,
182
+ "public_methods": public_methods[:30],
183
  }
184
+
185
+ # Try to find internal session/transport
186
+ for attr in ['_session', 'session', '_client', 'client', '_http_client',
187
+ '_api_url', 'api_url', 'base_url', '_base_url', 'inference_url',
188
+ 'headers', '_headers']:
189
+ val = getattr(client, attr, 'NOT_FOUND')
190
+ if val != 'NOT_FOUND':
191
+ diag['found_' + attr] = str(val)[:200] if not callable(val) else 'callable'
192
  except Exception as e:
193
+ diag = {"error": repr(e)[:300]}
194
 
195
  results["_diagnostics"] = diag
 
 
196
  results["_env"] = {
197
  "HF_INFERENCE_ENDPOINT": os.environ.get("HF_INFERENCE_ENDPOINT", "NOT SET"),
 
198
  "SPACE_ID": os.environ.get("SPACE_ID", "NOT SET"),
199
  }
200
 
201
+ # Test each model
202
  for name, fn, args in [
203
  ("summarization", hf_summarize, (long_text, 30, 10)),
204
  ("spelling", hf_correct_spelling, (test_text,)),