# inference.py
from huggingface_hub import InferenceClient

# torch, transformers, and gc are only used by the optional local quantized
# fallback at the bottom of this file.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import gc

def generate_response(model_cfg, prompt, max_new_tokens=512, temperature=0.7):
    """Generate a completion for `prompt` via Hugging Face inference providers.

    Tries the provider named in model_cfg first, then a fixed list of
    alternatives, and finally the default Hugging Face Inference API (None).
    """
    model_id = model_cfg["id"]
    primary_provider = model_cfg.get("provider")

    # Try order: primary -> groq -> nebius -> featherless-ai -> default (HF).
    # Deduplicate while preserving order so the primary provider is not retried.
    providers_to_try = []
    for p in [primary_provider, "groq", "nebius", "featherless-ai", None]:
        if p not in providers_to_try:
            providers_to_try.append(p)

    for prov in providers_to_try:
        try:
            client = InferenceClient(model=model_id, provider=prov)
        except Exception as client_err:
            print(f"Client init failed (provider={prov}): {client_err}")
            continue

        # Preferred path: OpenAI-compatible chat completion.
        try:
            messages = [{"role": "user", "content": prompt}]
            completion = client.chat.completions.create(
                messages=messages,
                max_tokens=max_new_tokens,
                temperature=temperature,
                stream=False,
            )
            return completion.choices[0].message.content.strip()
        except Exception as chat_err:
            print(f"Chat completion failed (provider={prov}): {chat_err}")

        # Fallback to the legacy text_generation endpoint.
        try:
            output = client.text_generation(
                prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                details=False,
            )
            return output if isinstance(output, str) else output.generated_text
        except Exception as text_err:
            print(f"Text generation also failed (provider={prov}): {text_err}")
            continue

    raise RuntimeError(
        f"Generation failed for {model_id} after trying providers: {providers_to_try}\n"
        "Check the model card for supported providers or try a different model."
    )
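
# Example call (illustrative only: the model id and "groq" provider below are
# assumptions for this sketch, not values required by this module):
#
#   cfg = {"id": "meta-llama/Meta-Llama-3-8B-Instruct", "provider": "groq"}
#   print(generate_response(cfg, "In one sentence, what makes a good LLM judge?"))
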
# Optional local quantized fallback (only if GPU hardware available)
# ... (keep your existing local code if needed)
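
# A minimal sketch of such a local fallback, for when the hosted providers above
# are unavailable. This is an assumption about how the local path could look
# (4-bit bitsandbytes quantization, CUDA only), not the project's actual
# implementation; the function name `generate_response_local` is hypothetical.
def generate_response_local(model_cfg, prompt, max_new_tokens=512, temperature=0.7):
    if not torch.cuda.is_available():
        raise RuntimeError("Local quantized fallback requires a CUDA GPU.")

    model_id = model_cfg["id"]
    quant_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quant_cfg,
        device_map="auto",
    )
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        new_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    finally:
        # Free GPU memory so repeated calls don't accumulate loaded models.
        del model
        gc.collect()
        torch.cuda.empty_cache()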