import os
from config import is_running_in_spaces
# Module-level cache of the model chosen via select_model(); None until one is set.
_selected_model: str | None = None
def _use_hf_backend() -> bool:
    """Decide whether requests go through the HF Inference API.

    Returns True when running on Spaces, or when an HF_TOKEN is configured
    and no Ollama endpoint has been set via OLLAMA_BASE_URL.
    """
    on_spaces = is_running_in_spaces()
    has_token = bool(os.environ.get("HF_TOKEN"))
    has_ollama = bool(os.environ.get("OLLAMA_BASE_URL"))
    return on_spaces or (has_token and not has_ollama)
# ---------------------------------------------------------------------------
# HF Inference API backend
# ---------------------------------------------------------------------------
def _hf_client():
    """Build an InferenceClient authenticated from HF_TOKEN (empty string if unset)."""
    from huggingface_hub import InferenceClient

    return InferenceClient(token=os.environ.get("HF_TOKEN", ""))
def _hf_list_models() -> list[str]:
return [
"meta-llama/Llama-3.1-8B-Instruct",
"mistralai/Mistral-7B-Instruct-v0.3",
"google/gemma-2-9b-it",
]
def _hf_generate_text(prompt: str, model: str) -> str:
    """Send *prompt* as a single-turn chat to *model*; return the stripped reply."""
    conversation = [{"role": "user", "content": prompt}]
    client = _hf_client()
    reply = client.chat_completion(
        model=model,
        messages=conversation,
        max_tokens=2048,
    )
    return reply.choices[0].message.content.strip()
# ---------------------------------------------------------------------------
# Ollama backend (original)
# ---------------------------------------------------------------------------
def _ollama_client():
    """Create an Ollama client pointed at the configured base URL."""
    import ollama
    from config import get_ollama_base_url

    host = get_ollama_base_url()
    return ollama.Client(host=host)
def _ollama_list_models() -> list[str]:
    """Return the names of all locally available Ollama models, sorted."""
    listing = _ollama_client().list()
    names = [entry.model for entry in listing.models]
    names.sort()
    return names
def _ollama_generate_text(prompt: str, model: str) -> str:
    """Run a single-turn chat against local Ollama *model*; return the stripped reply."""
    chat_result = _ollama_client().chat(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
    content = chat_result["message"]["content"]
    return content.strip()
# ---------------------------------------------------------------------------
# Public API (unchanged interface)
# ---------------------------------------------------------------------------
def list_models() -> list[str]:
    """List available model names from whichever backend is currently active."""
    backend = _hf_list_models if _use_hf_backend() else _ollama_list_models
    return backend()
def select_model(model: str) -> None:
global _selected_model
_selected_model = model
def get_active_model() -> str | None:
return _selected_model
def generate_text(prompt: str, model_name: str | None = None) -> str:
    """Generate a completion for *prompt* using the active backend.

    Args:
        prompt: The user prompt to send to the model.
        model_name: Explicit model to use; when omitted, falls back to the
            model previously registered via select_model().

    Returns:
        The model's reply with surrounding whitespace stripped.

    Raises:
        RuntimeError: If no model is given and none was previously selected.
    """
    # Fix: the original annotated model_name as plain `str` while defaulting to
    # None; `str | None` matches the actual contract (and the file's own style).
    model = model_name or _selected_model
    if not model:
        raise RuntimeError(
            "No model selected. Call select_model() first or pass model_name."
        )
    # Route to the HF Inference API on Spaces / token-only setups, else Ollama.
    if _use_hf_backend():
        return _hf_generate_text(prompt, model)
    return _ollama_generate_text(prompt, model)