workbykait committed on
Commit 6f59ec0 · verified · 1 Parent(s): 53f9d55

Update inference.py

Files changed (1)
  1. inference.py  +34 -51
inference.py CHANGED
@@ -7,59 +7,42 @@ import gc
 
 def generate_response(model_cfg, prompt, max_new_tokens=512, temperature=0.7):
     model_id = model_cfg["id"]
-    provider = model_cfg.get("provider", None)  # optional override
-
-    client = InferenceClient(model=model_id, provider=provider)
-
-    try:
-        # Prefer chat/completions style — much more reliable in 2026
-        messages = [{"role": "user", "content": prompt}]
-
-        completion = client.chat_completion(
-            messages=messages,
-            max_tokens=max_new_tokens,
-            temperature=temperature,
-            stream=False
-        )
-        return completion.choices[0].message.content.strip()
-
-    except AttributeError:
-        # Fallback to text_generation if chat_completion not available
+    primary_provider = model_cfg.get("provider")
+
+    # Try order: primary → groq → nebius → featherless-ai → default (HF)
+    providers_to_try = [primary_provider, "groq", "nebius", "featherless-ai", None]
+
+    for prov in [p for p in providers_to_try if p is not None or p == primary_provider]:
         try:
-            output = client.text_generation(
-                prompt,
-                max_new_tokens=max_new_tokens,
+            client = InferenceClient(model=model_id, provider=prov)
+            messages = [{"role": "user", "content": prompt}]
+            completion = client.chat.completions.create(
+                messages=messages,
+                max_tokens=max_new_tokens,
                 temperature=temperature,
-                details=False
+                stream=False
             )
-            return output.generated_text.strip() if hasattr(output, "generated_text") else output
-        except Exception as e_text:
-            raise RuntimeError(f"Both chat_completion and text_generation failed: {e_text}")
-
-    except Exception as e:
-        raise RuntimeError(
-            f"Generation failed for {model_id} (provider={provider}): {str(e)}\n"
-            "Try changing provider in models_config.py or use a different model."
-        )
-
-# Keep local quantized fallback only if you have GPU hardware
-# (comment out if running on CPU-only Space)
-def local_generate_fallback(model_cfg, prompt, max_new_tokens=512):
-    if not model_cfg.get("quantized", False):
-        return None
-
-    try:
-        bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
-        tokenizer = AutoTokenizer.from_pretrained(model_cfg["id"])
-        model = AutoModelForCausalLM.from_pretrained(
-            model_cfg["id"], quantization_config=bnb_config, device_map="auto"
-        )
-        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.7)
-        resp = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        del model, tokenizer, inputs, outputs
-        gc.collect()
-        torch.cuda.empty_cache() if torch.cuda.is_available() else None
-        return resp[len(prompt):].strip()
-    except Exception as e:
-        return f"[Local fallback failed: {str(e)}]"
+            return completion.choices[0].message.content.strip()
+
+        except Exception as chat_err:
+            print(f"Chat completion failed (provider={prov}): {chat_err}")
+            # Fallback to legacy text_generation
+            try:
+                output = client.text_generation(
+                    prompt,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    details=False
+                )
+                return output if isinstance(output, str) else output.generated_text
+            except Exception as text_err:
+                print(f"Text generation also failed (provider={prov}): {text_err}")
+                continue
+
+    raise RuntimeError(
+        f"Generation failed for {model_id} after trying providers: {providers_to_try}\n"
+        "Check model card for supported providers or try different models."
+    )
+
+# Optional local quantized fallback (only if GPU hardware available)
+# ... (keep your existing local code if needed)
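
A note on the new provider loop: the filter [p for p in providers_to_try if p is not None or p == primary_provider] skips the trailing None entry (the default Hugging Face routing) whenever model_cfg sets a provider, and tries the primary twice when it already appears in the hard-coded list. Below is a minimal sketch of one way to produce the order the comment describes, assuming the same model_cfg shape as above; build_provider_order is a hypothetical helper, not part of this commit.

# Hypothetical helper (not in this commit): build the order
# primary → groq → nebius → featherless-ai → default (None = HF routing)
# without duplicates and without dropping the default fallback.
def build_provider_order(model_cfg):
    primary = model_cfg.get("provider")  # optional override, may be None
    candidates = ["groq", "nebius", "featherless-ai", None]
    if primary is not None:
        candidates.insert(0, primary)
    ordered = []
    for provider in candidates:
        if provider not in ordered:  # keep the first occurrence only
            ordered.append(provider)
    return ordered

# Example: build_provider_order({"id": "some-model", "provider": "nebius"})
# -> ["nebius", "groq", "featherless-ai", None]

The loop could then read "for prov in build_provider_order(model_cfg):" with the body left unchanged.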