"""rag-governance-evaluator: hallucination / faithfulness detection for RAG responses.

Loads a causal LM once at import time, then serves:
  * a FastAPI JSON API -> /evaluate, /health, /stats
  * a Gradio demo UI   -> mounted at /
"""

import os
import re
import json
import time
import hashlib
from collections import OrderedDict
from typing import Optional, Dict, Any

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastapi import FastAPI, Request
import uvicorn

# ============================================================
# Config
# ============================================================
# Default to Mistral. Override in HF Space env vars if needed.
# Example:
#   MODEL_NAME=Qwen/Qwen2.5-1.5B-Instruct
MODEL_NAME = os.getenv("MODEL_NAME", "mistralai/Mistral-7B-Instruct-v0.2")
MAX_PROMPT_TOKENS = int(os.getenv("MAX_PROMPT_TOKENS", "1024"))
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "220"))

# Hard input limits (characters) so a single request cannot blow up the prompt.
MAX_QUERY_CHARS = 400
MAX_CONTEXT_CHARS = 1500
MAX_RESPONSE_CHARS = 1500
CACHE_MAX_ITEMS = 256

# Stats (in-memory; resets on restart)
MAX_STATS_EVENTS = 250


# ============================================================
# Utilities
# ============================================================
def _now_ts() -> str:
    """Human-readable local timestamp used in metrics and stats events."""
    return time.strftime("%Y-%m-%d %H:%M:%S")


def _clip(s: str, max_chars: int) -> str:
    """Strip and hard-truncate *s* to *max_chars*, appending an ellipsis if cut."""
    s = (s or "").strip()
    return s if len(s) <= max_chars else s[:max_chars] + "…"


def _cache_key(user_query: str, retrieved_context: str, ai_response: str) -> str:
    """Stable SHA-256 key over the three (already clipped) evaluation inputs."""
    blob = (user_query + "\n" + retrieved_context + "\n" + ai_response).encode(
        "utf-8", errors="ignore"
    )
    return hashlib.sha256(blob).hexdigest()


class LRUCache:
    """Tiny LRU cache on top of OrderedDict (no TTL; demo-scale only)."""

    def __init__(self, max_items: int = 256):
        self.max_items = max_items
        self.store: "OrderedDict[str, Dict[str, Any]]" = OrderedDict()

    def get(self, key: str) -> Optional[Dict[str, Any]]:
        if key not in self.store:
            return None
        self.store.move_to_end(key)  # mark as most recently used
        return self.store[key]

    def set(self, key: str, value: Dict[str, Any]) -> None:
        self.store[key] = value
        self.store.move_to_end(key)
        if len(self.store) > self.max_items:
            self.store.popitem(last=False)  # evict least recently used


CACHE = LRUCache(CACHE_MAX_ITEMS)


def _safe_fallback(
    reason: str,
    model_name: str,
    start_time: float,
    raw: Optional[str] = None,
) -> Dict[str, Any]:
    """Conservative 'flagged / review' verdict used whenever evaluation fails.

    If *raw* is given (the model's unparsable output), a truncated copy is
    appended to the reasoning for debuggability.
    """
    if raw:
        reason = f"{reason} Raw output (truncated): {raw[:600]}"
    return {
        "hallucination_score": 0.5,
        "faithfulness_status": "flagged",
        "risk_level": "medium",
        "reasoning": reason,
        "analysis": {"factual_errors": [], "unsupported_claims": [], "contradictions": []},
        "recommendation": "review",
        "performance_metrics": {
            "inference_time_ms": int((time.time() - start_time) * 1000),
            "model": model_name,
            "timestamp": _now_ts(),
        },
    }


# ============================================================
# JSON extraction & repair
# ============================================================
def _extract_first_json_object(text: str) -> Optional[str]:
    """
    Extract ONLY the first JSON object using brace balancing.
    Prevents '}{' double JSON issues.

    NOTE: braces inside string literals are not tracked; acceptable for this
    schema, whose string fields rarely contain braces.
    """
    if not text:
        return None
    start = text.find("{")
    if start < 0:
        return None
    depth = 0
    for i in range(start, len(text)):
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start : i + 1]
    return None


def _repair_json(s: str) -> str:
    """
    Light repair:
    - remove ``` fences
    - normalize smart quotes
    - remove trailing commas
    """
    s = (s or "").strip()
    s = re.sub(r"^```(?:json)?\s*", "", s, flags=re.IGNORECASE).strip()
    s = re.sub(r"\s*```$", "", s).strip()
    s = s.replace("“", '"').replace("”", '"').replace("’", "'").replace("‘", "'")
    s = re.sub(r",(\s*[}\]])", r"\1", s)
    return s


# ============================================================
# Deterministic contradiction check (cheap guardrail)
# ============================================================
_NEGATION_PHRASES = [
    "not supported", "does not support", "doesn't support",
    "unsupported", "no support", "not available",
]
_SUPPORT_PHRASES = [
    "supports", "supported", "yes", "available",
    "compatible", "works with", "has thunderbolt",
]


def _detect_obvious_contradiction(context: str, answer: str) -> Optional[str]:
    """Flag the classic failure: context says NOT supported, answer claims support.

    Substring-based and deliberately conservative.  If the answer itself
    contains a negation phrase it is NOT treated as a support claim — this
    fixes a false positive where e.g. "it is not supported" matched the
    substring "supported".
    """
    c = (context or "").lower()
    a = (answer or "").lower()
    ctx_negates = any(p in c for p in _NEGATION_PHRASES)
    ans_negates = any(p in a for p in _NEGATION_PHRASES)
    ans_affirms = any(p in a for p in _SUPPORT_PHRASES)
    if ctx_negates and ans_affirms and not ans_negates:
        return "Answer claims support while retrieved context explicitly indicates NOT supported."
    return None


# ============================================================
# Strict normalization (forces consistent outputs)
# ============================================================
def _normalize(result: Dict[str, Any], retrieved_context: str, ai_response: str) -> Dict[str, Any]:
    """Coerce a model-produced dict into the fixed governance schema.

    Guarantees: 'analysis' sub-lists exist, status/recommendation vocabulary is
    canonical, score drives status/risk/recommendation deterministically, and
    the cheap contradiction guardrail can override the model's score upward.
    """
    result = result or {}
    result.setdefault("analysis", {})
    result["analysis"] = result["analysis"] or {}
    result["analysis"].setdefault("factual_errors", [])
    result["analysis"].setdefault("unsupported_claims", [])
    result["analysis"].setdefault("contradictions", [])
    result.setdefault("reasoning", "")

    # Map model vocabulary variants onto the canonical status / recommendation.
    status_map = {
        "pass": "approved", "ok": "approved", "approve": "approved", "approved": "approved",
        "flag": "flagged", "flagged": "flagged",
        "unacceptable": "blocked", "reject": "blocked", "rejected": "blocked",
        "block": "blocked", "blocked": "blocked",
    }
    rec_map = {
        "approved": "approve", "approve": "approve", "accept": "approve",
        "review": "review", "flagged": "review",
        "reject": "block", "rejected": "block", "blocked": "block",
        "block": "block", "deny": "block",
    }

    fs = str(result.get("faithfulness_status", "flagged")).strip().lower()
    fs = status_map.get(fs, fs)
    rec = str(result.get("recommendation", "review")).strip().lower()
    rec = rec_map.get(rec, rec)

    try:
        score = float(result.get("hallucination_score", 0.5))
    except Exception:
        score = 0.5  # non-numeric score from the model -> neutral midpoint

    # Deterministic guardrail trumps the model: an obvious contradiction
    # forces the score into blocking territory.
    contradiction = _detect_obvious_contradiction(retrieved_context, ai_response)
    if contradiction:
        result["analysis"]["contradictions"].append(contradiction)
        score = max(score, 0.95)

    # Score bands are the single source of truth for status/risk/recommendation.
    if score >= 0.9:
        fs = "blocked"
        risk = "critical"
        rec = "block"
    elif score >= 0.4:
        fs = "flagged"
        risk = "medium"
        rec = "review"
    else:
        fs = "approved"
        risk = "low"
        rec = "approve"

    result["hallucination_score"] = round(float(score), 2)
    result["faithfulness_status"] = fs
    result["risk_level"] = risk
    result["recommendation"] = rec

    if not result["reasoning"]:
        if fs == "blocked":
            result["reasoning"] = "Detected contradiction or strong unfaithfulness relative to retrieved context."
        elif fs == "flagged":
            result["reasoning"] = "Potential unfaithfulness/unsupported content detected relative to retrieved context."
        else:
            result["reasoning"] = "Response appears faithful to retrieved context."
    return result


# ============================================================
# Prompt
# ============================================================
SYSTEM_PROMPT = (
    "You are an evaluator for RAG systems. The retrieved context is the ONLY source of truth. "
    "Return ONLY JSON. No markdown. No extra text."
)

# NOTE(review): the schema placeholders below were reconstructed — the source
# had empty values (e.g. '"hallucination_score": ,'), consistent with <...>
# placeholder tokens having been stripped by HTML sanitization.
USER_PROMPT_TEMPLATE = """Evaluate the AI response against the retrieved context.

USER QUERY:
{user_query}

RETRIEVED CONTEXT:
{retrieved_context}

AI RESPONSE:
{ai_response}

Return ONLY JSON exactly in this schema:
{{
  "hallucination_score": <float 0.0-1.0>,
  "faithfulness_status": "<approved|flagged|blocked>",
  "risk_level": "<low|medium|high|critical>",
  "reasoning": "<short explanation>",
  "analysis": {{
    "factual_errors": ["<item>"],
    "unsupported_claims": ["<item>"],
    "contradictions": ["<item>"]
  }},
  "recommendation": "<approve|review|block>"
}}

Strict rules:
- If AI response contradicts retrieved context -> hallucination_score >= 0.9, status=blocked, risk_level=critical, recommendation=block.
- If partially unsupported -> score 0.4-0.8, status=flagged, risk_level=medium/high, recommendation=review.
- If fully supported -> score <= 0.2, status=approved, risk_level=low, recommendation=approve.

Return JSON only.
"""

# ============================================================
# Load model/tokenizer
# ============================================================
print(f"Loading model: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token_id is None:
    # Some causal LMs (incl. Mistral) ship without a pad token; reuse EOS.
    tokenizer.pad_token = tokenizer.eos_token

torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch_dtype,
    device_map="auto",
)
model.eval()
print("Model loaded successfully!")

# ============================================================
# Stats (demo usage)
# ============================================================
STATS = {
    "boot_ts": _now_ts(),
    "total_ui_loads": 0,
    "total_evals": 0,
    "unique_clients": set(),  # hashed fingerprints
    "events": [],  # last N events
}


def _get_client_ip(request: Request) -> str:
    """Best-effort client IP; honors the first X-Forwarded-For hop behind proxies."""
    xff = request.headers.get("x-forwarded-for")
    if xff:
        return xff.split(",")[0].strip()
    if request.client:
        return request.client.host
    return "unknown"


def _fingerprint(request: Request) -> str:
    """Short anonymized client fingerprint (hashed IP + user agent)."""
    ip = _get_client_ip(request)
    ua = request.headers.get("user-agent", "")
    return hashlib.sha256(f"{ip}|{ua}".encode("utf-8", errors="ignore")).hexdigest()[:16]


def _record_event(event_type: str, request: Optional[Request] = None) -> None:
    """Append a stats event, tracking unique clients and capping history length."""
    fp = None
    if request is not None:
        fp = _fingerprint(request)
        STATS["unique_clients"].add(fp)
    STATS["events"].append({"ts": _now_ts(), "type": event_type, "client": fp})
    STATS["events"] = STATS["events"][-MAX_STATS_EVENTS:]


# ============================================================
# Core evaluation
# ============================================================
def evaluate_response(user_query: str, retrieved_context: str, ai_response: str) -> Dict[str, Any]:
    """Evaluate *ai_response* against *retrieved_context* for faithfulness.

    Returns the normalized governance dict (score, status, risk, reasoning,
    analysis, recommendation, performance_metrics).  Results are LRU-cached by
    input hash; cache hits get fresh timing metrics with a '(cache)' marker.
    Never raises: all failures degrade to a 'flagged / review' fallback.
    """
    start_time = time.time()

    user_query = _clip(user_query, MAX_QUERY_CHARS)
    retrieved_context = _clip(retrieved_context, MAX_CONTEXT_CHARS)
    ai_response = _clip(ai_response, MAX_RESPONSE_CHARS)

    key = _cache_key(user_query, retrieved_context, ai_response)
    cached = CACHE.get(key)
    if cached is not None:
        out = dict(cached)  # shallow copy so the cached entry's metrics stay intact
        out["performance_metrics"] = {
            "inference_time_ms": int((time.time() - start_time) * 1000),
            "model": f"{MODEL_NAME} (cache)",
            "timestamp": _now_ts(),
        }
        return out

    try:
        user_prompt = USER_PROMPT_TEMPLATE.format(
            user_query=user_query,
            retrieved_context=retrieved_context,
            ai_response=ai_response,
        )
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]
        if hasattr(tokenizer, "apply_chat_template"):
            prompt_text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            prompt_text = SYSTEM_PROMPT + "\n\n" + user_prompt

        inputs = tokenizer(
            prompt_text,
            return_tensors="pt",
            max_length=MAX_PROMPT_TOKENS,
            truncation=True,
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Greedy decoding for determinism (cache-friendly).  Do NOT pass
        # temperature/top_p here: with do_sample=False they are ignored, and
        # temperature=0.0 is rejected by recent transformers versions.
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        gen_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        gen_text = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()

        json_text = _extract_first_json_object(gen_text)
        if not json_text:
            # Deterministic for these inputs (greedy decoding), so caching is safe.
            result = _safe_fallback(
                "Could not find JSON in model output.", MODEL_NAME, start_time, raw=gen_text
            )
            CACHE.set(key, result)
            return result

        json_text = _repair_json(json_text)
        try:
            result = json.loads(json_text)
        except Exception:
            result = _safe_fallback(
                "Model returned invalid JSON.", MODEL_NAME, start_time, raw=gen_text
            )
            CACHE.set(key, result)
            return result

        result = _normalize(result, retrieved_context, ai_response)
        result["performance_metrics"] = {
            "inference_time_ms": int((time.time() - start_time) * 1000),
            "model": MODEL_NAME,
            "timestamp": _now_ts(),
        }
        CACHE.set(key, result)
        return result

    except Exception as e:
        # Transient failures (OOM, tokenizer errors, ...) are NOT cached:
        # caching here would make a one-off error sticky for identical inputs.
        return _safe_fallback(f"Evaluation error: {str(e)}", MODEL_NAME, start_time)


# ============================================================
# FastAPI app
# ============================================================
app = FastAPI()


@app.get("/health")
def health():
    """Liveness probe with model identity and boot timestamp."""
    return {"status": "ok", "model": MODEL_NAME, "boot_ts": STATS["boot_ts"]}


@app.get("/stats")
def stats():
    """Lightweight in-memory usage metrics (reset on every restart)."""
    return {
        "boot_ts": STATS["boot_ts"],
        "total_ui_loads": STATS["total_ui_loads"],
        "total_evals": STATS["total_evals"],
        "unique_clients_count": len(STATS["unique_clients"]),
        "recent_events": STATS["events"][-25:],
    }


@app.post("/evaluate")
async def evaluate(payload: dict, request: Request):
    """JSON API entry point; mirrors the Gradio UI's evaluation."""
    STATS["total_evals"] += 1
    _record_event("evaluate", request)
    return evaluate_response(
        payload.get("user_query", ""),
        payload.get("retrieved_context", ""),
        payload.get("ai_response", ""),
    )


@app.middleware("http")
async def _ui_load_counter(request: Request, call_next):
    # best-effort UI load counter (root page loads)
    if request.method == "GET" and request.url.path == "/":
        STATS["total_ui_loads"] += 1
        _record_event("ui_load", request)
    return await call_next(request)


# ============================================================
# Gradio UI (with Dropdown Examples)
# ============================================================
def gradio_evaluate(user_query, retrieved_context, ai_response):
    """Adapter: run an evaluation and pretty-print the result for the UI."""
    return json.dumps(evaluate_response(user_query, retrieved_context, ai_response), indent=2)


# (query, context, response) triples keyed by example name.
EXAMPLES_MAP = {
    "Thunderbolt contradiction": (
        "Does this laptop support Thunderbolt 4?",
        "Specs: USB-C only. Thunderbolt 4 NOT supported.",
        # Fixed: the previous response ("No support for Thunderbolt 4") was
        # faithful to the context and could not demonstrate a contradiction.
        "Yes, it fully supports Thunderbolt 4.",
    ),
    "Refund window contradiction": (
        "What is the refund window?",
        "Refunds are allowed within 30 days with receipt.",
        "Refunds are allowed within 60 days with a receipt.",
    ),
    "Unsupported pediatric claim": (
        "Is this medication safe for children?",
        "This document only covers adult dosage. No pediatric guidance provided.",
        "Yes, it is safe for children at half the adult dose.",
    ),
    "Warranty contradiction": (
        "What is the warranty length?",
        "Warranty: 1 year limited warranty on parts and labor.",
        "The warranty is 2 years full coverage.",
    ),
}


def load_example(name: str):
    """Return the (query, context, response) triple for the chosen example."""
    q, ctx, resp = EXAMPLES_MAP[name]
    return q, ctx, resp


with gr.Blocks(title="RAG-Governance-Evaluator") as demo:
    gr.Markdown("# rag-governance-evaluator")
    gr.Markdown("hallucination / faithfulness detection for RAG responses (demo)")

    with gr.Row():
        with gr.Column():
            example_choice = gr.Dropdown(
                label="Examples",
                choices=list(EXAMPLES_MAP.keys()),
                value="Thunderbolt contradiction",
            )
            load_btn = gr.Button("Load Example", variant="secondary")

            query_input = gr.Textbox(label="User Query", lines=2)
            context_input = gr.Textbox(label="Retrieved Context", lines=6)
            response_input = gr.Textbox(label="AI Response", lines=4)

            load_btn.click(
                fn=load_example,
                inputs=[example_choice],
                outputs=[query_input, context_input, response_input],
            )

            evaluate_btn = gr.Button("Evaluate", variant="primary")

        with gr.Column():
            output = gr.Textbox(label="Evaluation Results (JSON)", lines=22)

    evaluate_btn.click(
        fn=gradio_evaluate,
        inputs=[query_input, context_input, response_input],
        outputs=output,
    )

    gr.Markdown("### Usage")
    gr.Markdown(
        "- Pick an example and click **Load Example**, or type your own inputs.\n"
        "- Click **Evaluate** to get a structured governance decision.\n"
        "- View lightweight usage metrics at **/stats** (counts reset on restart)."
    )

# Mount Gradio at /
app = gr.mount_gradio_app(app, demo, path="/", ssr_mode=False)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)