"""
KnowLedge Inference — inference.py
Requires transformers>=5.0.0. Gemma 4 (E4B) is a multimodal model whose
architecture is Gemma4ForConditionalGeneration (added in transformers 5.x).
ZeroGPU pattern
---------------
* Model loaded on CPU at startup with torch_dtype=torch.bfloat16 (~8 GB RAM).
* Each @spaces.GPU() call gets a real CUDA device; we move model → cuda there,
run generation, then move back → cpu and empty the cache so the next request
can reuse the shared GPU slot.
* health_check() is lightweight — it never triggers model download.
"""
import base64, io, logging, sys, re
logger = logging.getLogger(__name__)
try:
    import spaces
except ModuleNotFoundError:

    class _SpacesStub:
        """No-op stand-in for the ``spaces`` package when it is not installed.

        Lets the module be imported and run locally: functions decorated with
        ``spaces.GPU`` are returned unchanged.
        """

        @staticmethod
        def GPU(task=None, *, duration=60, **_options):
            """Mimic ``spaces.GPU`` without doing anything.

            Supports both decorator spellings:

            * ``@spaces.GPU`` -- *task* is the decorated function, returned as-is.
            * ``@spaces.GPU(...)`` / ``@spaces.GPU(duration=90)`` -- returns a
              decorator that hands the function back untouched.

            ``duration`` and any other keyword options are accepted and ignored.
            A non-callable positional value (e.g. ``GPU(90)``) is also ignored,
            which keeps the old positional-duration call form working.
            """
            if callable(task):
                # Bare ``@spaces.GPU`` usage: we were handed the function itself.
                return task

            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesStub()
# Optional heavy dependencies.  Every name is pre-declared so the module can
# always be imported and later code can test availability through the HAS_*
# flags (or ``is None``) instead of hitting a NameError.
_import_error = None  # "<ExceptionType>: <message>" of a failed transformers/torch import
HAS_TRANSFORMERS = False
Gemma4Processor = None
AutoModelForCausalLM = None
torch = None
try:
    from transformers import Gemma4Processor, AutoModelForCausalLM
    import torch
    HAS_TRANSFORMERS = True
except Exception as _e:
    # Deliberately broad: whatever went wrong is recorded as text rather than
    # aborting the module import.
    _import_error = f"{type(_e).__name__}: {_e}"

# Pillow is optional as well; ``Image`` stays None when it is missing, mirroring
# the transformers names above.
HAS_PIL = False
Image = None
try:
    from PIL import Image
    HAS_PIL = True
except ImportError:
    pass
# Repository id handed to ``from_pretrained`` by load_model() and reported by
# health_check().
MODEL_REPO = "unsloth/gemma-4-E4B-it"
# Module-level cache written by load_model(); None until load_model() assigns them.
_processor = None
_model = None
def load_model():
    """Return the cached ``(processor, model)`` pair, loading it on first use.

    The pair is stored in the module globals ``_processor`` / ``_model``; once
    ``_model`` is set, later calls return the cached objects without loading
    anything again.

    Returns:
        Tuple ``(processor, model)``.

    Raises:
        RuntimeError: if the transformers/torch import failed at module import
            time (the recorded import error is included in the message).
    """
    global _processor, _model

    if _model is None:
        if not HAS_TRANSFORMERS:
            raise RuntimeError(f"transformers import failed: {_import_error}")

        logger.info("Loading %s on CPU (bfloat16)…", MODEL_REPO)
        _processor = Gemma4Processor.from_pretrained(MODEL_REPO)

        model_kwargs = {
            # The original author notes that ``dtype=`` was silently ignored and
            # that bfloat16 keeps the footprint around 8 GB.
            # NOTE(review): re-check the keyword against the pinned transformers version.
            "torch_dtype": torch.bfloat16,
            # Stream weights while loading to lower peak RAM.
            "low_cpu_mem_usage": True,
        }
        _model = AutoModelForCausalLM.from_pretrained(MODEL_REPO, **model_kwargs)
        _model.eval()

        first_param = next(_model.parameters())
        logger.info("Model loaded. Device: %s", first_param.device)

    return _processor, _model
@spaces.GPU(duration=90)
def generate_text(model_name: str, prompt: str, max_new_tokens: int = 512) -> str:
    """Generate a text-only reply to *prompt* with the shared model.

    The model is moved to CUDA when available, used for one sampled generation,
    and always moved back to CPU afterwards (with the CUDA cache emptied) so the
    GPU is released even when generation fails.

    Args:
        model_name: Accepted for interface compatibility only; it is not used.
            The model named by ``MODEL_REPO`` is always the one that runs.
        prompt: User message sent after a fixed system instruction.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded completion (prompt tokens excluded) with any
        ``<think>...</think>`` block removed and surrounding whitespace stripped.

    Raises:
        RuntimeError: propagated from ``load_model()`` when transformers/torch
            could not be imported.
    """
    processor, model = load_model()
    cuda_available = torch.cuda.is_available()
    device = "cuda" if cuda_available else "cpu"
    logger.info("generate_text: device=%s cuda_available=%s", device, cuda_available)

    messages = [
        {"role": "system", "content": "You are a helpful educational assistant."},
        {"role": "user", "content": prompt},
    ]

    try:
        # Inside the try so the finally clause also cleans up after a failed or
        # partial transfer to the GPU.
        model.to(device)

        try:
            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
            )
        except TypeError:
            # Retry without ``enable_thinking`` -- presumably for processors whose
            # template call rejects that keyword.
            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

        # Always pass the prompt as ``text=``: multimodal processors take images
        # as their first positional argument, so a positional string is unsafe.
        inputs = processor(text=text, return_tensors="pt").to(device)
        input_len = inputs["input_ids"].shape[-1]

        # Use the wrapped tokenizer's EOS id when the processor has one,
        # otherwise the processor's own attribute.
        eos = getattr(processor, "tokenizer", processor).eos_token_id

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=1.0,
                top_p=0.95,
                top_k=64,
                do_sample=True,
                pad_token_id=eos,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        raw = processor.decode(outputs[0][input_len:], skip_special_tokens=True)
        # The previous pattern r".*?" could only match empty strings, which made
        # the substitution a no-op.
        # NOTE(review): <think>...</think> is the assumed reasoning delimiter --
        # confirm against the model's chat template.
        return re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL).strip()
    finally:
        model.to("cpu")
        if device == "cuda":
            torch.cuda.empty_cache()
@spaces.GPU(duration=120)
def generate_with_image(prompt: str, image_base64: str, max_new_tokens: int = 512) -> str:
    """Generate a reply to *prompt* about a base64-encoded image.

    Best-effort: if Pillow is unavailable, the image cannot be decoded, or the
    multimodal inputs cannot be built, the failure is logged and the request is
    answered by ``generate_text`` using the prompt alone.  All of that
    preparation runs before the model is moved to the GPU, so a fallback never
    happens while this function holds the model on the device.

    Args:
        prompt: Text instruction accompanying the image.
        image_base64: Base64-encoded image bytes; converted to RGB after decoding.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded completion (prompt tokens excluded) with any
        ``<think>...</think>`` block removed and surrounding whitespace stripped.

    Raises:
        RuntimeError: propagated from ``load_model()`` when transformers/torch
            could not be imported.
    """
    processor, model = load_model()
    cuda_available = torch.cuda.is_available()
    device = "cuda" if cuda_available else "cpu"
    logger.info("generate_with_image: device=%s cuda_available=%s", device, cuda_available)

    if not HAS_PIL:
        logger.warning("generate_with_image: Pillow is not installed; answering from text only")
        return generate_text("e4b", prompt, max_new_tokens)

    try:
        image = Image.open(io.BytesIO(base64.b64decode(image_base64))).convert("RGB")
    except Exception:
        # Deliberately broad: any undecodable payload degrades to text-only.
        logger.warning(
            "generate_with_image: could not decode image; answering from text only",
            exc_info=True,
        )
        return generate_text("e4b", prompt, max_new_tokens)

    messages = [{"role": "user", "content": [
        {"type": "image", "image": image},
        {"type": "text", "text": prompt},
    ]}]
    try:
        inputs = processor.apply_chat_template(
            messages, tokenize=True, return_dict=True,
            return_tensors="pt", add_generation_prompt=True,
        )
    except Exception:
        logger.warning(
            "generate_with_image: could not build multimodal inputs; answering from text only",
            exc_info=True,
        )
        return generate_text("e4b", prompt, max_new_tokens)

    try:
        # Inside the try so the finally clause also cleans up after a failed or
        # partial transfer to the GPU.
        model.to(device)
        # Move every tensor-like entry to the target device; leave anything
        # without a ``.to`` method untouched.
        inputs = {
            key: (value.to(device) if hasattr(value, "to") else value)
            for key, value in inputs.items()
        }
        input_len = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=1.0,
                top_p=0.95,
                top_k=64,
                do_sample=True,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        raw = processor.decode(outputs[0][input_len:], skip_special_tokens=True)
        # The previous pattern r".*?" could only match empty strings, which made
        # the substitution a no-op.
        # NOTE(review): <think>...</think> is the assumed reasoning delimiter --
        # confirm against the model's chat template.
        return re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL).strip()
    finally:
        model.to("cpu")
        if device == "cuda":
            torch.cuda.empty_cache()
def health_check() -> dict:
    """Report the runtime state without loading the model.

    Only module-level state is read; ``load_model()`` is never called.  The
    function also works when transformers/torch could not be imported: the
    failure shows up as ``"status": "error"`` with the recorded
    ``"import_error"`` instead of an exception.

    Returns:
        A dict with the keys ``backend``, ``transformers_version`` (``None``
        when transformers cannot be imported), ``model_repo``,
        ``has_transformers``, ``import_error``, ``python``, ``cuda_available``,
        ``model_loaded``, ``status`` (``"ok"`` or ``"error"``) and ``device``
        (``"not_loaded_yet"`` until the model has been loaded).
    """
    try:
        import transformers

        transformers_version = getattr(transformers, "__version__", None)
    except Exception:
        # Broad on purpose, like the module-level import guard: a health check
        # must describe a broken install, not crash on it.
        transformers_version = None

    has_cuda = torch.cuda.is_available() if torch is not None else False
    if _model is not None:
        device = str(next(_model.parameters()).device)
    else:
        device = "not_loaded_yet"

    return {
        "backend": "transformers",
        "transformers_version": transformers_version,
        "model_repo": MODEL_REPO,
        "has_transformers": HAS_TRANSFORMERS,
        "import_error": _import_error,
        "python": sys.version,
        "cuda_available": has_cuda,
        "model_loaded": _model is not None,
        "status": "ok" if HAS_TRANSFORMERS else "error",
        "device": device,
    }