# app.py — RAG governance evaluator (Hugging Face Space, commit c26a2f5 by vishalkhadake30)
import os
import re
import json
import time
import hashlib
from collections import OrderedDict
from typing import Optional, Dict, Any
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastapi import FastAPI, Request
import uvicorn
# ============================================================
# Config
# ============================================================
# Default to Mistral (as you wanted). Override in HF Space env vars if needed.
# Example:
# MODEL_NAME=Qwen/Qwen2.5-1.5B-Instruct
MODEL_NAME = os.getenv("MODEL_NAME", "mistralai/Mistral-7B-Instruct-v0.2")
MAX_PROMPT_TOKENS = int(os.getenv("MAX_PROMPT_TOKENS", "1024"))
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "220"))
MAX_QUERY_CHARS = 400
MAX_CONTEXT_CHARS = 1500
MAX_RESPONSE_CHARS = 1500
CACHE_MAX_ITEMS = 256
# Stats (in-memory; resets on restart)
MAX_STATS_EVENTS = 250
# ============================================================
# Utilities
# ============================================================
def _now_ts() -> str:
return time.strftime("%Y-%m-%d %H:%M:%S")
def _clip(s: str, max_chars: int) -> str:
s = (s or "").strip()
return s if len(s) <= max_chars else s[:max_chars] + "…"
def _cache_key(user_query: str, retrieved_context: str, ai_response: str) -> str:
blob = (user_query + "\n" + retrieved_context + "\n" + ai_response).encode("utf-8", errors="ignore")
return hashlib.sha256(blob).hexdigest()
class LRUCache:
    """Minimal least-recently-used cache backed by an OrderedDict.

    `get` refreshes an entry's recency; `set` inserts/refreshes and evicts
    the oldest entry once the cache exceeds `max_items`.
    """

    def __init__(self, max_items: int = 256):
        self.max_items = max_items
        # OrderedDict keeps insertion order; the most-recently-used key is
        # kept at the end, so eviction pops from the front.
        self.store: "OrderedDict[str, Dict[str, Any]]" = OrderedDict()

    def get(self, key: str) -> Optional[Dict[str, Any]]:
        """Return the cached value for *key* (marking it most-recent), or None."""
        try:
            value = self.store[key]
        except KeyError:
            return None
        self.store.move_to_end(key)
        return value

    def set(self, key: str, value: Dict[str, Any]) -> None:
        """Insert or replace *key*, mark it most-recent, evict LRU if over capacity."""
        self.store[key] = value
        self.store.move_to_end(key)
        while len(self.store) > self.max_items:
            self.store.popitem(last=False)
# Module-wide evaluation cache, keyed by SHA-256 of (query, context, response).
CACHE = LRUCache(CACHE_MAX_ITEMS)
def _safe_fallback(reason: str, model_name: str, start_time: float, raw: Optional[str] = None) -> Dict[str, Any]:
    """Build the neutral 'flagged / review' verdict used whenever evaluation fails.

    When *raw* model output is available, the first 600 characters are appended
    to the reasoning to aid debugging.
    """
    message = reason
    if raw:
        message = f"{reason} Raw output (truncated): {raw[:600]}"
    elapsed_ms = int((time.time() - start_time) * 1000)
    return {
        "hallucination_score": 0.5,
        "faithfulness_status": "flagged",
        "risk_level": "medium",
        "reasoning": message,
        "analysis": {"factual_errors": [], "unsupported_claims": [], "contradictions": []},
        "recommendation": "review",
        "performance_metrics": {
            "inference_time_ms": elapsed_ms,
            "model": model_name,
            "timestamp": _now_ts(),
        },
    }
# ============================================================
# JSON extraction & repair
# ============================================================
def _extract_first_json_object(text: str) -> Optional[str]:
"""
Extract ONLY the first JSON object using brace balancing.
Prevents '}{' double JSON issues.
"""
if not text:
return None
start = text.find("{")
if start < 0:
return None
depth = 0
for i in range(start, len(text)):
ch = text[i]
if ch == "{":
depth += 1
elif ch == "}":
depth -= 1
if depth == 0:
return text[start : i + 1]
return None
def _repair_json(s: str) -> str:
"""
Light repair:
- remove ``` fences
- normalize smart quotes
- remove trailing commas
"""
s = (s or "").strip()
s = re.sub(r"^```(?:json)?\s*", "", s, flags=re.IGNORECASE).strip()
s = re.sub(r"\s*```$", "", s).strip()
s = s.replace("“", '"').replace("”", '"').replace("’", "'").replace("‘", "'")
s = re.sub(r",(\s*[}\]])", r"\1", s)
return s
# ============================================================
# Deterministic contradiction check (cheap guardrail)
# ============================================================
def _detect_obvious_contradiction(context: str, answer: str) -> Optional[str]:
c = (context or "").lower()
a = (answer or "").lower()
neg = any(p in c for p in [
"not supported", "does not support", "doesn't support", "unsupported", "no support", "not available"
])
pos = any(p in a for p in [
"supports", "supported", "yes", "available", "compatible", "works with", "has thunderbolt"
])
if neg and pos:
return "Answer claims support while retrieved context explicitly indicates NOT supported."
return None
# ============================================================
# Strict normalization (forces consistent outputs)
# ============================================================
def _normalize(result: Dict[str, Any], retrieved_context: str, ai_response: str) -> Dict[str, Any]:
    """
    Force the model's parsed JSON verdict into one consistent shape.

    - Guarantees the 'analysis' sub-fields exist and are actual lists (models
      sometimes emit strings or null there), so the append below cannot raise.
    - Clamps the hallucination score into [0, 1].
    - Runs the deterministic contradiction guardrail; a hit records the
      contradiction and escalates the score to at least 0.95.
    - Derives status / risk_level / recommendation purely from fixed score
      thresholds so the numeric score and the labels can never disagree.
      (The model's own status/recommendation strings were previously mapped
      through synonym tables, but those results were always overwritten by
      the thresholds below — the dead mapping has been removed.)
    """
    result = result or {}
    analysis = result.get("analysis") or {}
    if not isinstance(analysis, dict):
        analysis = {}
    for field in ("factual_errors", "unsupported_claims", "contradictions"):
        value = analysis.get(field)
        if value is None:
            analysis[field] = []
        elif not isinstance(value, list):
            # Robustness: wrap a scalar/string into a one-item list.
            analysis[field] = [value]
    result["analysis"] = analysis
    result.setdefault("reasoning", "")
    try:
        score = float(result.get("hallucination_score", 0.5))
    except (TypeError, ValueError):
        score = 0.5
    # Clamp out-of-range model scores instead of letting them skew thresholds.
    score = min(max(score, 0.0), 1.0)
    contradiction = _detect_obvious_contradiction(retrieved_context, ai_response)
    if contradiction:
        result["analysis"]["contradictions"].append(contradiction)
        score = max(score, 0.95)
    # Labels derive ONLY from the score: >=0.9 blocked, >=0.4 flagged, else approved.
    if score >= 0.9:
        fs, risk, rec = "blocked", "critical", "block"
    elif score >= 0.4:
        fs, risk, rec = "flagged", "medium", "review"
    else:
        fs, risk, rec = "approved", "low", "approve"
    result["hallucination_score"] = round(score, 2)
    result["faithfulness_status"] = fs
    result["risk_level"] = risk
    result["recommendation"] = rec
    if not result["reasoning"]:
        if fs == "blocked":
            result["reasoning"] = "Detected contradiction or strong unfaithfulness relative to retrieved context."
        elif fs == "flagged":
            result["reasoning"] = "Potential unfaithfulness/unsupported content detected relative to retrieved context."
        else:
            result["reasoning"] = "Response appears faithful to retrieved context."
    return result
# ============================================================
# Prompt
# ============================================================
# System message: pins the evaluator role and demands bare JSON output.
SYSTEM_PROMPT = (
    "You are an evaluator for RAG systems. The retrieved context is the ONLY source of truth. "
    "Return ONLY JSON. No markdown. No extra text."
)
# Per-request prompt template. Double braces ({{ }}) escape literal JSON
# braces so that str.format only substitutes {user_query},
# {retrieved_context} and {ai_response}.
USER_PROMPT_TEMPLATE = """Evaluate the AI response against the retrieved context.
USER QUERY:
{user_query}
RETRIEVED CONTEXT:
{retrieved_context}
AI RESPONSE:
{ai_response}
Return ONLY JSON exactly in this schema:
{{
"hallucination_score": <float 0.0-1.0>,
"faithfulness_status": "<approved|flagged|blocked>",
"risk_level": "<low|medium|high|critical>",
"reasoning": "<string>",
"analysis": {{
"factual_errors": ["<string>"],
"unsupported_claims": ["<string>"],
"contradictions": ["<string>"]
}},
"recommendation": "<approve|review|block>"
}}
Strict rules:
- If AI response contradicts retrieved context -> hallucination_score >= 0.9, status=blocked, risk_level=critical, recommendation=block.
- If partially unsupported -> score 0.4-0.8, status=flagged, risk_level=medium/high, recommendation=review.
- If fully supported -> score <= 0.2, status=approved, risk_level=low, recommendation=approve.
Return JSON only.
"""
# ============================================================
# Load model/tokenizer
# ============================================================
# Module-level load: happens once at import time, before the app serves traffic.
print(f"Loading model: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Some checkpoints ship without a pad token; reuse EOS so generate() can pad.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
# fp16 on GPU, fp32 on CPU (half precision on CPU is slow/poorly supported).
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch_dtype,
    device_map="auto",  # let accelerate place weights on available devices
)
model.eval()  # inference mode: disables dropout etc.
print("Model loaded successfully!")
# ============================================================
# Stats (demo usage)
# ============================================================
# In-memory usage counters surfaced by the /stats endpoint; reset on restart.
STATS = {
    "boot_ts": _now_ts(),
    "total_ui_loads": 0,
    "total_evals": 0,
    "unique_clients": set(),  # hashed client fingerprints (see _fingerprint)
    "events": [],  # rolling log of the last MAX_STATS_EVENTS events
}
def _get_client_ip(request: Request) -> str:
    """Best-effort client IP: first X-Forwarded-For hop, else the socket peer,
    else the literal string 'unknown'."""
    forwarded = request.headers.get("x-forwarded-for")
    if forwarded:
        first_hop = forwarded.split(",")[0]
        return first_hop.strip()
    client = request.client
    if client:
        return client.host
    return "unknown"
def _fingerprint(request: Request) -> str:
    """Short, privacy-preserving client id: first 16 hex chars of the SHA-256
    of 'ip|user-agent'."""
    identity = f"{_get_client_ip(request)}|{request.headers.get('user-agent', '')}"
    digest = hashlib.sha256(identity.encode("utf-8", errors="ignore")).hexdigest()
    return digest[:16]
def _record_event(event_type: str, request: Optional[Request] = None) -> None:
    """Append a usage event to STATS['events'], keeping only the newest
    MAX_STATS_EVENTS entries; also registers the client fingerprint when a
    request is available."""
    client_fp = None
    if request is not None:
        client_fp = _fingerprint(request)
        STATS["unique_clients"].add(client_fp)
    events = STATS["events"]
    events.append({"ts": _now_ts(), "type": event_type, "client": client_fp})
    STATS["events"] = events[-MAX_STATS_EVENTS:]
# ============================================================
# Core evaluation
# ============================================================
def evaluate_response(user_query: str, retrieved_context: str, ai_response: str) -> Dict[str, Any]:
    """
    Evaluate *ai_response* for faithfulness against *retrieved_context*.

    Pipeline: clip inputs -> LRU cache lookup -> build chat prompt -> greedy
    generation -> extract/repair/parse the first JSON object -> normalize the
    verdict. Any failure falls back to a neutral 'flagged / review' result via
    _safe_fallback. Results (including fallbacks) are cached by input hash;
    cache hits get fresh performance_metrics with the model name suffixed
    '(cache)'.
    """
    start_time = time.time()
    # Bound the inputs so the assembled prompt stays within MAX_PROMPT_TOKENS.
    user_query = _clip(user_query, MAX_QUERY_CHARS)
    retrieved_context = _clip(retrieved_context, MAX_CONTEXT_CHARS)
    ai_response = _clip(ai_response, MAX_RESPONSE_CHARS)
    key = _cache_key(user_query, retrieved_context, ai_response)
    cached = CACHE.get(key)
    if cached is not None:
        # Shallow copy so the fresh metrics don't mutate the cached entry.
        out = dict(cached)
        out["performance_metrics"] = {
            "inference_time_ms": int((time.time() - start_time) * 1000),
            "model": f"{MODEL_NAME} (cache)",
            "timestamp": _now_ts(),
        }
        return out
    try:
        user_prompt = USER_PROMPT_TEMPLATE.format(
            user_query=user_query,
            retrieved_context=retrieved_context,
            ai_response=ai_response,
        )
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]
        # Prefer the model's own chat template; fall back to plain concatenation.
        if hasattr(tokenizer, "apply_chat_template"):
            prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        else:
            prompt_text = SYSTEM_PROMPT + "\n\n" + user_prompt
        inputs = tokenizer(
            prompt_text,
            return_tensors="pt",
            max_length=MAX_PROMPT_TOKENS,
            truncation=True,
        )
        # Move tensors to wherever device_map="auto" placed the model.
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
                # NOTE(review): temperature/top_p are inert when do_sample=False
                # (greedy decoding) — confirm against transformers generate docs.
                temperature=0.0,
                top_p=1.0,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        gen_tokens = outputs[0][inputs["input_ids"].shape[-1] :]
        gen_text = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
        json_text = _extract_first_json_object(gen_text)
        if not json_text:
            # No JSON at all — cache and return the neutral fallback verdict.
            result = _safe_fallback("Could not find JSON in model output.", MODEL_NAME, start_time, raw=gen_text)
            CACHE.set(key, result)
            return result
        json_text = _repair_json(json_text)
        try:
            result = json.loads(json_text)
        except Exception:
            result = _safe_fallback("Model returned invalid JSON.", MODEL_NAME, start_time, raw=gen_text)
            CACHE.set(key, result)
            return result
        # Force the verdict into the canonical schema before caching.
        result = _normalize(result, retrieved_context, ai_response)
        result["performance_metrics"] = {
            "inference_time_ms": int((time.time() - start_time) * 1000),
            "model": MODEL_NAME,
            "timestamp": _now_ts(),
        }
        CACHE.set(key, result)
        return result
    except Exception as e:
        # Catch-all boundary: any unexpected failure becomes a review verdict.
        result = _safe_fallback(f"Evaluation error: {str(e)}", MODEL_NAME, start_time)
        CACHE.set(key, result)
        return result
# ============================================================
# FastAPI app
# ============================================================
app = FastAPI()
@app.get("/health")
def health():
    """Liveness probe: reports the configured model and process boot time."""
    return {"status": "ok", "model": MODEL_NAME, "boot_ts": STATS["boot_ts"]}
@app.get("/stats")
def stats():
    """Lightweight in-memory usage metrics (reset on every restart)."""
    return {
        "boot_ts": STATS["boot_ts"],
        "total_ui_loads": STATS["total_ui_loads"],
        "total_evals": STATS["total_evals"],
        "unique_clients_count": len(STATS["unique_clients"]),
        "recent_events": STATS["events"][-25:],
    }
@app.post("/evaluate")
async def evaluate(payload: dict, request: Request):
    """JSON API: evaluate {user_query, retrieved_context, ai_response}.
    Missing fields default to empty strings."""
    STATS["total_evals"] += 1
    _record_event("evaluate", request)
    return evaluate_response(
        payload.get("user_query", ""),
        payload.get("retrieved_context", ""),
        payload.get("ai_response", ""),
    )
@app.middleware("http")
async def _ui_load_counter(request: Request, call_next):
    """Middleware that counts GET / page loads as UI visits (best effort)."""
    # best-effort UI load counter (root page loads)
    if request.method == "GET" and request.url.path == "/":
        STATS["total_ui_loads"] += 1
        _record_event("ui_load", request)
    return await call_next(request)
# ============================================================
# Gradio UI (with Dropdown Examples)
# ============================================================
def gradio_evaluate(user_query, retrieved_context, ai_response):
    """Gradio wrapper: run the evaluator and pretty-print the verdict as JSON."""
    verdict = evaluate_response(user_query, retrieved_context, ai_response)
    return json.dumps(verdict, indent=2)
# Canned demo inputs for the UI dropdown: name -> (query, context, response).
# Each triple is crafted so the response contradicts or overreaches its context.
EXAMPLES_MAP = {
    "Thunderbolt contradiction": (
        "Does this laptop support Thunderbolt 4?",
        "Specs: USB-C only. Thunderbolt 4 NOT supported.",
        "No support for Thunderbolt 4",
    ),
    "Refund window contradiction": (
        "What is the refund window?",
        "Refunds are allowed within 30 days with receipt.",
        "Refunds are allowed within 60 days with a receipt.",
    ),
    "Unsupported pediatric claim": (
        "Is this medication safe for children?",
        "This document only covers adult dosage. No pediatric guidance provided.",
        "Yes, it is safe for children at half the adult dose.",
    ),
    "Warranty contradiction": (
        "What is the warranty length?",
        "Warranty: 1 year limited warranty on parts and labor.",
        "The warranty is 2 years full coverage.",
    ),
}
def load_example(name: str):
    """Return the (query, context, response) triple for the named example."""
    query, context, response = EXAMPLES_MAP[name]
    return query, context, response
# Build the Gradio demo UI: example picker + three inputs on the left,
# JSON verdict output on the right.
with gr.Blocks(title="RAG-Governance-Evaluator") as demo:
    gr.Markdown("# rag-governance-evaluator")
    gr.Markdown("hallucination / faithfulness detection for RAG responses (demo)")
    with gr.Row():
        with gr.Column():
            example_choice = gr.Dropdown(
                label="Examples",
                choices=list(EXAMPLES_MAP.keys()),
                value="Thunderbolt contradiction",
            )
            load_btn = gr.Button("Load Example", variant="secondary")
            query_input = gr.Textbox(label="User Query", lines=2)
            context_input = gr.Textbox(label="Retrieved Context", lines=6)
            response_input = gr.Textbox(label="AI Response", lines=4)
            # Clicking "Load Example" fills the three inputs from EXAMPLES_MAP.
            load_btn.click(
                fn=load_example,
                inputs=[example_choice],
                outputs=[query_input, context_input, response_input],
            )
            evaluate_btn = gr.Button("Evaluate", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Evaluation Results (JSON)", lines=22)
    evaluate_btn.click(fn=gradio_evaluate, inputs=[query_input, context_input, response_input], outputs=output)
    gr.Markdown("### Usage")
    gr.Markdown(
        "- Pick an example and click **Load Example**, or type your own inputs.\n"
        "- Click **Evaluate** to get a structured governance decision.\n"
        "- View lightweight usage metrics at **/stats** (counts reset on restart)."
    )
# Mount Gradio at / (the FastAPI routes above keep their own paths)
app = gr.mount_gradio_app(app, demo, path="/", ssr_mode=False)
if __name__ == "__main__":
    # Direct-run entrypoint; 7860 is the standard Hugging Face Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)