Spaces:
Sleeping
Sleeping
import json
import os
import re
import threading
from pathlib import Path
from typing import Any

from pydantic import BaseModel

from environment.actions import EpisodeObservation
# Runtime configuration, each value overridable through an environment variable.
MODEL_ID = os.environ.get("MODEL_ID", "Siddh12334/qwen-1.5b-context-corruption")
MAX_NEW_TOKENS = int(os.environ.get("MODEL_MAX_NEW_TOKENS", "128"))
# Only "0", "false" or "no" (case-insensitive) turn the trained model off.
ENABLE_TRAINED_MODEL = os.environ.get("ENABLE_TRAINED_MODEL", "true").lower() not in (
    "0",
    "false",
    "no",
)

# Lazily populated model state; _LOCK is held by _load_model() while loading.
_LOCK = threading.Lock()
_MODEL = _TOKENIZER = None
_CACHE_DIR = None
_LOAD_STARTED = False
_LOAD_ERROR = None
# Input payload for inference: wraps the single observation to respond to.
class InferenceRequest(BaseModel):
    # Episode state; this module reads its question, documents and flagged_ids.
    observation: EpisodeObservation
# Result returned by run_inference() for both the trained-model and heuristic paths.
class InferenceResponse(BaseModel):
    # Decoded model output, or a JSON-formatted string on the heuristic path.
    text: str
    # Model identifier; every constructor call in this module passes MODEL_ID.
    loaded_model: str
    # "trained" by default; the heuristic fallback sets "heuristic".
    mode: str = "trained"
    # True by default; the heuristic fallback sets False.
    model_ready: bool = True
def configure_runtime_dirs() -> Path:
    """Create the runtime/cache directories and point the environment at them.

    The root comes from ``MODEL_RUNTIME_DIR`` (default
    ``/tmp/context-corruption-model``). All directories are created first,
    then ``HOME``, ``XDG_CACHE_HOME``, ``HF_HOME``, ``HF_HUB_CACHE`` and
    ``TRANSFORMERS_CACHE`` are overwritten in ``os.environ``.

    Returns:
        The Hugging Face hub cache directory.
    """
    runtime_root = Path(os.getenv("MODEL_RUNTIME_DIR", "/tmp/context-corruption-model"))
    cache_root = runtime_root / "cache"
    hf_home = cache_root / "huggingface"
    hub_cache = hf_home / "hub"

    locations = {
        "HOME": runtime_root,
        "XDG_CACHE_HOME": cache_root,
        "HF_HOME": hf_home,
        "HF_HUB_CACHE": hub_cache,
        "TRANSFORMERS_CACHE": hf_home / "transformers",
    }

    # Create every directory before touching the environment, so a failed
    # mkdir leaves os.environ untouched.
    for directory in locations.values():
        directory.mkdir(parents=True, exist_ok=True)
    os.environ.update({name: str(directory) for name, directory in locations.items()})

    return hub_cache


# Runs at import time so the cache variables are set before any model library loads.
_CACHE_DIR = configure_runtime_dirs()
def model_status() -> dict[str, Any]:
    """Report the current model-loading state and CUDA availability.

    Returns:
        A dict with ``model_id``, ``loaded``, ``loading``, ``load_error``,
        ``enabled``, ``cuda_available`` and ``cuda_device``. ``loading`` is
        true only while a load has started, no model is present yet and no
        error has been recorded.
    """
    # torch is imported lazily everywhere in this module. A deployment that
    # runs with the trained model disabled may not have it installed, and the
    # status report should still work there.
    try:
        torch = _import_torch()
    except ImportError:
        torch = None
    cuda_available = torch is not None and torch.cuda.is_available()

    return {
        "model_id": MODEL_ID,
        "loaded": _MODEL is not None,
        "loading": _LOAD_STARTED and _MODEL is None and _LOAD_ERROR is None,
        "load_error": _LOAD_ERROR,
        "enabled": ENABLE_TRAINED_MODEL,
        "cuda_available": cuda_available,
        "cuda_device": torch.cuda.get_device_name(0) if cuda_available else None,
    }
| def _import_torch(): | |
| import torch | |
| return torch | |
def warm_model_async() -> None:
    """Start loading the model on a background daemon thread, at most once.

    Does nothing when the trained model is disabled, when a model is already
    loaded, or when a load has already been started.
    """
    global _LOAD_STARTED

    if not ENABLE_TRAINED_MODEL:
        return
    if _MODEL is not None or _LOAD_STARTED:
        return

    # Mark the load as started before the thread runs so later calls return early.
    _LOAD_STARTED = True
    threading.Thread(
        target=_load_model_safely,
        name="model-loader",
        daemon=True,
    ).start()
def _load_model_safely():
    """Run ``_load_model()`` and record any failure instead of raising.

    Used as the background thread target; the exception message is stored in
    ``_LOAD_ERROR``.
    """
    global _LOAD_ERROR

    try:
        _load_model()
    except Exception as load_failure:  # pragma: no cover - depends on remote model/runtime.
        _LOAD_ERROR = str(load_failure)
def _load_model():
    """Load the base model plus PEFT adapter once and cache them module-wide.

    The adapter repository ``MODEL_ID`` and the base model named in its PEFT
    config are downloaded into ``_CACHE_DIR``. When CUDA is available the
    base model is loaded with 4-bit NF4 quantization and bfloat16 compute;
    otherwise it is loaded on CPU in float32.

    Returns:
        ``(model, tokenizer)``; the cached pair on every call after the first
        successful load.

    Raises:
        Whatever the download or model-loading libraries raise. Nothing is
        cached in that case.
    """
    global _MODEL, _TOKENIZER, _LOAD_ERROR, _LOAD_STARTED

    # Fast path: already loaded, no need to take the lock.
    if _MODEL is not None and _TOKENIZER is not None:
        return _MODEL, _TOKENIZER

    with _LOCK:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if _MODEL is not None and _TOKENIZER is not None:
            return _MODEL, _TOKENIZER

        _LOAD_STARTED = True
        _LOAD_ERROR = None

        torch = _import_torch()
        from peft import PeftConfig, PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        from huggingface_hub import snapshot_download

        use_cuda = torch.cuda.is_available()

        quantization_config = None
        if use_cuda:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            )

        adapter_path = snapshot_download(
            repo_id=MODEL_ID,
            cache_dir=str(_CACHE_DIR),
        )
        peft_config = PeftConfig.from_pretrained(adapter_path)

        # Restrict the base-model download to the listed file patterns.
        base_path = snapshot_download(
            repo_id=peft_config.base_model_name_or_path,
            cache_dir=str(_CACHE_DIR),
            allow_patterns=[
                "*.json",
                "*.safetensors",
                "*.model",
                "*.txt",
                "*.jinja",
            ],
        )

        base_model = AutoModelForCausalLM.from_pretrained(
            base_path,
            device_map="auto" if use_cuda else "cpu",
            torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
            quantization_config=quantization_config,
            low_cpu_mem_usage=True,
        )

        tokenizer = AutoTokenizer.from_pretrained(adapter_path)
        model = PeftModel.from_pretrained(
            base_model,
            adapter_path,
        )
        model.eval()

        # Publish only fully initialised objects. run_inference() reads these
        # globals without the lock, so the model must already be in eval mode
        # when it becomes visible.
        _TOKENIZER = tokenizer
        _MODEL = model
        return _MODEL, _TOKENIZER
def _format_prompt(observation: EpisodeObservation) -> list[dict[str, str]]:
    """Build the two-message chat prompt (system + user) for *observation*.

    Each document is rendered as ``[Doc <id>] <title>`` followed by its
    content; documents are separated by a blank line and appended after the
    question in the user message.
    """
    rendered_docs = []
    for doc in observation.documents:
        rendered_docs.append(f"[Doc {doc.id}] {doc.title}\n{doc.content}")
    docs_text = "\n\n".join(rendered_docs)

    system_message = {
        "role": "system",
        "content": (
            "You are an epistemic agent. Answer the question and identify corrupted documents. "
            'Respond ONLY as JSON: {"answer": "...", "suspicious_docs": [0], "confidence": 0.8}'
        ),
    }
    user_message = {
        "role": "user",
        "content": f"Question: {observation.question}\n\nDocuments:\n{docs_text}",
    }
    return [system_message, user_message]
def run_inference(observation: EpisodeObservation) -> InferenceResponse:
    """Answer *observation* with the trained model, or a heuristic fallback.

    The heuristic path (``_fast_inference``) is used when the trained model
    is disabled, has not finished loading, or failed to load. This function
    never waits for the model to load: if it is not ready, a background load
    is requested via ``warm_model_async()`` and a heuristic response is
    returned straight away.

    Returns:
        An ``InferenceResponse`` whose ``text`` is the decoded generation
        (trained path) or a JSON string (heuristic path).
    """
    if not ENABLE_TRAINED_MODEL:
        return _fast_inference(observation, "Trained model loading is disabled.")

    if _MODEL is None or _TOKENIZER is None:
        if _LOAD_ERROR is not None:
            # warm_model_async() does not start another attempt once
            # _LOAD_STARTED is set, so "warming up" would be misleading here.
            return _fast_inference(
                observation,
                "Trained model failed to load. Returning a fast heuristic response for the demo.",
            )
        warm_model_async()
        return _fast_inference(
            observation,
            "Trained model is warming up. Returning a fast heuristic response for the demo.",
        )

    # Both globals are set, so this takes _load_model()'s lock-free fast path.
    model, tokenizer = _load_model()
    torch = _import_torch()

    messages = _format_prompt(observation)
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Drop the prompt tokens so only the newly generated continuation is decoded.
    generated_ids = output_ids[0][inputs["input_ids"].shape[-1]:]
    text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    return InferenceResponse(text=text, loaded_model=MODEL_ID)
def _fast_inference(observation: EpisodeObservation, note: str) -> InferenceResponse:
    """Produce a heuristic answer without running the trained model.

    A candidate answer is extracted from every document that is not in
    ``observation.flagged_ids``; the most frequent candidate wins, with
    ``"unknown"`` as the fallback. The flagged ids are reported as the
    suspicious documents.

    Args:
        observation: Episode state providing the question, documents and
            flagged ids.
        note: Explanation of why the heuristic path was used; included in
            the JSON under ``"note"``.

    Returns:
        An ``InferenceResponse`` with ``mode="heuristic"`` and
        ``model_ready=False`` whose ``text`` is a JSON object string.
    """
    flagged = set(observation.flagged_ids)

    candidates = []
    for doc in observation.documents:
        if doc.id in flagged:
            continue
        candidate = _extract_candidate_answer(observation.question, doc.content)
        if candidate:
            candidates.append(candidate)

    payload = {
        "answer": _majority_vote(candidates) or "unknown",
        "suspicious_docs": sorted(flagged),
        "confidence": 0.55,
        "note": note,
    }
    # json.dumps keeps the text valid JSON for string ids and for answers
    # containing control characters. Its default separators match the layout
    # the previous hand-built string used.
    text = json.dumps(payload, ensure_ascii=False)

    return InferenceResponse(
        text=text,
        loaded_model=MODEL_ID,
        mode="heuristic",
        model_ready=False,
    )
| def _extract_candidate_answer(question: str, content: str) -> str | None: | |
| patterns = [ | |
| r"\banswer\s+is\s+([^.;,\n]+)", | |
| r"\banswer\s+remains\s+([^.;,\n]+)", | |
| r"\brecords\s+([^.;,\n]+)\s+as\s+the\s+answer", | |
| r"\bconfirms\s+that\s+([^.;,\n]+)", | |
| ] | |
| for pattern in patterns: | |
| match = re.search(pattern, content, flags=re.IGNORECASE) | |
| if match: | |
| return match.group(1).strip(" '\"") | |
| quoted = re.findall(r'"([^"]{2,80})"', content) | |
| if quoted: | |
| return quoted[-1].strip() | |
| # Last resort: use a short proper-noun span that is not just copied from the question. | |
| question_terms = set(re.findall(r"[A-Z][a-z]+", question)) | |
| spans = re.findall(r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+){0,3}\b", content) | |
| for span in reversed(spans): | |
| if span not in question_terms and len(span.split()) <= 4: | |
| return span.strip() | |
| return None | |
| def _majority_vote(candidates: list[str]) -> str | None: | |
| if not candidates: | |
| return None | |
| counts: dict[str, tuple[int, str]] = {} | |
| for candidate in candidates: | |
| key = re.sub(r"\W+", " ", candidate).strip().lower() | |
| if not key: | |
| continue | |
| count, original = counts.get(key, (0, candidate)) | |
| counts[key] = (count + 1, original) | |
| if not counts: | |
| return candidates[0] | |
| return max(counts.values(), key=lambda item: item[0])[1] | |
| def _json_escape(value: str) -> str: | |
| return value.replace("\\", "\\\\").replace('"', '\\"') | |