# context-corruption-env / environment/model_inference.py
# aagparekh's picture
# Add interactive frontend UI
# b0c701c
import json
import os
import re
import threading
from pathlib import Path
from typing import Any

from pydantic import BaseModel

from environment.actions import EpisodeObservation
# --- Configuration, read from the environment once at import time ---

# Hub repository id passed to ``snapshot_download`` when fetching the adapter.
MODEL_ID = os.environ.get("MODEL_ID", "Siddh12334/qwen-1.5b-context-corruption")

# Upper bound on generated tokens, forwarded to ``generate`` as ``max_new_tokens``.
MAX_NEW_TOKENS = int(os.environ.get("MODEL_MAX_NEW_TOKENS", "128"))

# The trained model is enabled unless the variable is "0", "false" or "no"
# (compared case-insensitively).
ENABLE_TRAINED_MODEL = os.environ.get("ENABLE_TRAINED_MODEL", "true").lower() not in (
    "0",
    "false",
    "no",
)

# --- Process-wide state, filled in lazily by the loader functions ---

# Serializes model loading.
_LOCK = threading.Lock()

# Loaded model and tokenizer; both stay None until a load succeeds.
_MODEL = None
_TOKENIZER = None

# Placeholder; rebound to the hub cache path right after
# ``configure_runtime_dirs`` is defined.
_CACHE_DIR = None

# Load bookkeeping: whether a load was ever kicked off, and the text of the
# last recorded load failure (None when there is none).
_LOAD_STARTED = False
_LOAD_ERROR = None
class InferenceRequest(BaseModel):
    """Pydantic model wrapping the single observation to run inference on.

    NOTE(review): presumably the request payload of the inference endpoint;
    confirm against the caller, which is not visible in this file.
    """
    # Episode observation; its question and documents are what the prompt
    # builder and the heuristic fallback in this module read.
    observation: EpisodeObservation
class InferenceResponse(BaseModel):
    """Result of one inference call, as built by ``run_inference`` and ``_fast_inference``."""
    # Decoded model output, or the JSON string produced by the heuristic fallback.
    text: str
    # Model identifier the response is attributed to (this module always passes MODEL_ID).
    loaded_model: str
    # "trained" by default; the heuristic fallback sets "heuristic".
    mode: str = "trained"
    # True by default; the heuristic fallback sets False.
    model_ready: bool = True
def configure_runtime_dirs() -> Path:
    """Create the runtime/cache directory tree and point the process at it.

    The root comes from ``MODEL_RUNTIME_DIR`` (default
    ``/tmp/context-corruption-model``). Every directory is created first;
    only then are ``HOME``, ``XDG_CACHE_HOME``, ``HF_HOME``, ``HF_HUB_CACHE``
    and ``TRANSFORMERS_CACHE`` overwritten in ``os.environ``.

    Returns:
        The hub cache directory (the value exported as ``HF_HUB_CACHE``).
    """
    runtime_root = Path(os.getenv("MODEL_RUNTIME_DIR", "/tmp/context-corruption-model"))
    cache_root = runtime_root / "cache"
    hf_root = cache_root / "huggingface"
    hub_cache = hf_root / "hub"

    targets = {
        "HOME": runtime_root,
        "XDG_CACHE_HOME": cache_root,
        "HF_HOME": hf_root,
        "HF_HUB_CACHE": hub_cache,
        "TRANSFORMERS_CACHE": hf_root / "transformers",
    }

    # Create everything before touching the environment, so a failing mkdir
    # leaves the environment variables unchanged.
    for directory in targets.values():
        directory.mkdir(parents=True, exist_ok=True)

    os.environ.update({name: str(directory) for name, directory in targets.items()})
    return hub_cache
# Runs at import time: creates the cache directories, exports the cache-related
# environment variables, and keeps the hub cache path that is later passed as
# ``cache_dir`` to the snapshot downloads.
_CACHE_DIR = configure_runtime_dirs()
def model_status() -> dict[str, Any]:
    """Report the current loader state and CUDA availability.

    Returns:
        A dict with the configured model id, whether the model is loaded,
        whether a load is in progress (started, no model yet, no recorded
        error), the last load error, the enable flag, CUDA availability and
        the name of CUDA device 0 (``None`` without CUDA).
    """
    torch = _import_torch()
    has_cuda = torch.cuda.is_available()
    device_name = torch.cuda.get_device_name(0) if has_cuda else None
    still_loading = _LOAD_STARTED and _MODEL is None and _LOAD_ERROR is None

    status: dict[str, Any] = {
        "model_id": MODEL_ID,
        "loaded": _MODEL is not None,
        "loading": still_loading,
        "load_error": _LOAD_ERROR,
        "enabled": ENABLE_TRAINED_MODEL,
        "cuda_available": has_cuda,
        "cuda_device": device_name,
    }
    return status
def _import_torch():
    """Import ``torch`` on demand and return the module object.

    The import lives inside the function so that importing this module does
    not itself import torch.
    """
    import torch as torch_module

    return torch_module
def warm_model_async() -> None:
    """Kick off model loading on a background daemon thread, at most once.

    Does nothing when the trained model is disabled, when a model is already
    loaded, or when a load has already been started.
    """
    global _LOAD_STARTED

    already_handled = _MODEL is not None or _LOAD_STARTED
    if ENABLE_TRAINED_MODEL and not already_handled:
        # Set the flag before starting the thread so later callers return early.
        _LOAD_STARTED = True
        threading.Thread(
            target=_load_model_safely,
            name="model-loader",
            daemon=True,
        ).start()
def _load_model_safely():
    """Run ``_load_model`` and store any failure in ``_LOAD_ERROR`` instead of raising.

    This is the target of the background loader thread; the recorded message
    is what ``model_status`` reports as ``load_error``.
    """
    global _LOAD_ERROR

    try:
        _load_model()
    except Exception as load_failure:  # pragma: no cover - depends on remote model/runtime.
        _LOAD_ERROR = str(load_failure)
def _load_model():
    """Load the base model plus PEFT adapter once and cache them in module globals.

    The first call downloads the adapter snapshot for ``MODEL_ID``, reads the
    base model name from the adapter's PEFT config, downloads that base
    snapshot, and wraps the base model with the adapter. Later calls return
    the cached pair.

    ``_MODEL`` and ``_TOKENIZER`` are only assigned after the model is fully
    built and switched to eval mode. Other code in this module reads those
    globals without taking ``_LOCK``, so publishing them earlier would let a
    reader pick up a model that is still being initialised.

    Returns:
        The ``(model, tokenizer)`` pair.

    Raises:
        Any exception raised by the imports, downloads or model construction
        is propagated unchanged.
    """
    global _MODEL, _TOKENIZER, _LOAD_ERROR, _LOAD_STARTED

    # Lock-free fast path once both globals have been published.
    if _MODEL is not None and _TOKENIZER is not None:
        return _MODEL, _TOKENIZER

    with _LOCK:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if _MODEL is not None and _TOKENIZER is not None:
            return _MODEL, _TOKENIZER

        _LOAD_STARTED = True
        _LOAD_ERROR = None

        torch = _import_torch()
        from peft import PeftConfig, PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        from huggingface_hub import snapshot_download

        cuda_available = torch.cuda.is_available()

        # 4-bit NF4 quantization is only configured when CUDA is available.
        quantization_config = None
        if cuda_available:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            )

        adapter_path = snapshot_download(
            repo_id=MODEL_ID,
            cache_dir=str(_CACHE_DIR),
        )
        peft_config = PeftConfig.from_pretrained(adapter_path)

        # Only files matching these patterns are fetched for the base model.
        base_path = snapshot_download(
            repo_id=peft_config.base_model_name_or_path,
            cache_dir=str(_CACHE_DIR),
            allow_patterns=[
                "*.json",
                "*.safetensors",
                "*.model",
                "*.txt",
                "*.jinja",
            ],
        )

        base_model = AutoModelForCausalLM.from_pretrained(
            base_path,
            device_map="auto" if cuda_available else "cpu",
            torch_dtype=torch.bfloat16 if cuda_available else torch.float32,
            quantization_config=quantization_config,
            low_cpu_mem_usage=True,
        )

        # The tokenizer is loaded from the adapter snapshot, not the base one.
        tokenizer = AutoTokenizer.from_pretrained(adapter_path)
        model = PeftModel.from_pretrained(
            base_model,
            adapter_path,
        )
        model.eval()

        # Publish last, tokenizer first: a reader that sees both globals set
        # is guaranteed a fully initialised pair.
        _TOKENIZER = tokenizer
        _MODEL = model
        return model, tokenizer
def _format_prompt(observation: EpisodeObservation) -> list[dict[str, str]]:
    """Build the two-message chat (system + user) for one observation.

    Each document is rendered as a ``[Doc <id>] <title>`` header followed by
    its content; documents are separated by a blank line and appended to the
    question in the user message.

    Returns:
        ``[{"role": "system", ...}, {"role": "user", ...}]``.
    """
    rendered_docs = [
        f"[Doc {doc.id}] {doc.title}\n{doc.content}" for doc in observation.documents
    ]
    docs_text = "\n\n".join(rendered_docs)

    system_message = (
        "You are an epistemic agent. Answer the question and identify corrupted documents. "
        'Respond ONLY as JSON: {"answer": "...", "suspicious_docs": [0], "confidence": 0.8}'
    )
    user_message = f"Question: {observation.question}\n\nDocuments:\n{docs_text}"

    conversation = (("system", system_message), ("user", user_message))
    return [{"role": role, "content": content} for role, content in conversation]
def run_inference(observation: EpisodeObservation) -> InferenceResponse:
    """Answer one observation with the trained model, or with the heuristic fallback.

    The heuristic fallback is returned when the trained model is disabled, or
    when the model/tokenizer are not loaded yet (in which case a background
    load is requested first). Otherwise the chat prompt is generated greedily
    and only the newly generated tokens are decoded.

    Returns:
        An ``InferenceResponse``; trained-model responses keep the default
        ``mode`` and ``model_ready`` values.
    """
    if not ENABLE_TRAINED_MODEL:
        return _fast_inference(observation, "Trained model loading is disabled.")

    model_is_ready = _MODEL is not None and _TOKENIZER is not None
    if not model_is_ready:
        warm_model_async()
        warming_note = "Trained model is warming up. Returning a fast heuristic response for the demo."
        return _fast_inference(observation, warming_note)

    model, tokenizer = _load_model()
    torch = _import_torch()

    messages = _format_prompt(observation)
    chat_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    encoded = tokenizer(chat_prompt, return_tensors="pt").to(model.device)

    generation_kwargs = {
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": False,
        "pad_token_id": tokenizer.eos_token_id,
    }
    with torch.inference_mode():
        output_ids = model.generate(**encoded, **generation_kwargs)

    # The output starts with the prompt tokens; keep only what was generated.
    prompt_length = encoded["input_ids"].shape[-1]
    completion_ids = output_ids[0][prompt_length:]
    completion = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
    return InferenceResponse(text=completion, loaded_model=MODEL_ID)
def _fast_inference(observation: EpisodeObservation, note: str) -> InferenceResponse:
    """Produce a heuristic answer without the trained model.

    A candidate answer is extracted from every document that is not already
    flagged, the candidates are majority-voted, and the flagged ids are
    reported as the suspicious documents.

    The response text is serialized with ``json.dumps`` so it is always valid
    JSON, including when the answer or note contains newlines or other
    control characters. Key order and separators match the previous
    hand-built string.

    Args:
        observation: Episode observation providing the question, documents
            and flagged ids.
        note: Human-readable explanation of why the heuristic path was used;
            included in the JSON under ``"note"``.

    Returns:
        An ``InferenceResponse`` with ``mode="heuristic"`` and
        ``model_ready=False``.
    """
    flagged = set(observation.flagged_ids)

    candidates = []
    for doc in observation.documents:
        if doc.id in flagged:
            continue
        candidate = _extract_candidate_answer(observation.question, doc.content)
        if candidate:
            candidates.append(candidate)

    payload = {
        "answer": _majority_vote(candidates) or "unknown",
        "suspicious_docs": sorted(flagged),
        "confidence": 0.55,
        "note": note,
    }
    # ensure_ascii=False keeps non-ASCII text verbatim, as the old formatter did.
    text = json.dumps(payload, ensure_ascii=False)

    return InferenceResponse(
        text=text,
        loaded_model=MODEL_ID,
        mode="heuristic",
        model_ready=False,
    )
def _extract_candidate_answer(question: str, content: str) -> str | None:
patterns = [
r"\banswer\s+is\s+([^.;,\n]+)",
r"\banswer\s+remains\s+([^.;,\n]+)",
r"\brecords\s+([^.;,\n]+)\s+as\s+the\s+answer",
r"\bconfirms\s+that\s+([^.;,\n]+)",
]
for pattern in patterns:
match = re.search(pattern, content, flags=re.IGNORECASE)
if match:
return match.group(1).strip(" '\"")
quoted = re.findall(r'"([^"]{2,80})"', content)
if quoted:
return quoted[-1].strip()
# Last resort: use a short proper-noun span that is not just copied from the question.
question_terms = set(re.findall(r"[A-Z][a-z]+", question))
spans = re.findall(r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+){0,3}\b", content)
for span in reversed(spans):
if span not in question_terms and len(span.split()) <= 4:
return span.strip()
return None
def _majority_vote(candidates: list[str]) -> str | None:
if not candidates:
return None
counts: dict[str, tuple[int, str]] = {}
for candidate in candidates:
key = re.sub(r"\W+", " ", candidate).strip().lower()
if not key:
continue
count, original = counts.get(key, (0, candidate))
counts[key] = (count + 1, original)
if not counts:
return candidates[0]
return max(counts.values(), key=lambda item: item[0])[1]
def _json_escape(value: str) -> str:
    """Escape ``value`` for embedding between double quotes in a JSON document.

    Delegates to ``json.dumps`` so that, besides backslashes and double
    quotes, newlines, tabs and other control characters are escaped too;
    without that the surrounding JSON would be invalid. Non-ASCII characters
    are left as-is.

    Args:
        value: The raw text.

    Returns:
        The escaped text, without the surrounding quotes.
    """
    # json.dumps returns the string wrapped in double quotes; drop them.
    return json.dumps(value, ensure_ascii=False)[1:-1]