SamSankar committed on
Commit
9d28801
·
verified ·
1 Parent(s): b2d2dc5

Upload folder using huggingface_hub

Files changed (5)
  1. README.md +168 -280
  2. evaluate_groq.py +64 -0
  3. hallucination_guard_sdk.py +345 -0
  4. server/app.py +243 -70
  5. server/dataset_loader.py +0 -0
README.md CHANGED
@@ -13,359 +13,247 @@ tags:
13
  - grounded-generation
14
  - question-answering
15
  - fact-checking
 
16
  - llm-training
 
17
  ---
18
 
19
- # 🛡️ HallucinationGuard-Env
20
 
21
- > **An OpenEnv reinforcement learning environment that trains AI models to answer only from verified context — penalizing hallucination and rewarding factual grounding.**
22
 
23
- [![OpenEnv](https://img.shields.io/badge/OpenEnv-Compatible-blue)](https://github.com/meta-pytorch/OpenEnv)
24
- [![License](https://img.shields.io/badge/License-MIT-green)](LICENSE)
25
- [![Dataset](https://img.shields.io/badge/Dataset-2000%2B_examples-orange)](#datasets)
 
26
 
27
  ---
28
 
29
- ## 💡 The Inspiration
30
 
31
- During research for the Meta PyTorch OpenEnv Hackathon, an AI model confidently hallucinated a **"golden ticket backdoor"** claiming that Ideathon winners could skip directly to the Grand Finale. This information existed nowhere in the official sources. The AI stated it with high confidence and even fabricated a supporting quote.
32
 
33
- That moment made one thing clear: hallucination isn't just an academic problem. It causes real confusion in high-stakes situations.
34
 
35
- **HallucinationGuard-Env** was built to fix that — training AI models to say *"I don't know"* when they don't, cite real sources when they do, and never fabricate with confidence.
 
 
 
36
 
37
  ---
38
 
39
- ## 🚀 Quick Start
40
 
41
- ```bash
42
- # Install
43
- pip install -e .
44
-
45
- # Run locally
46
- uvicorn server.app:app --reload
47
-
48
- # Health check
49
- curl http://localhost:8000/health
50
- # → {"status": "healthy", "service": "HallucinationGuard-Env"}
51
-
52
- # Deploy to HuggingFace Spaces
53
- openenv push --repo-id your-username/hallucination-guard-env
54
- ```
55
-
56
- ---
57
-
58
- ## 🎮 How The Environment Works
59
-
60
- The agent receives a **question** and a **source document**. It must answer using only what the document says, provide a direct quote supporting its answer, and state how confident it is.
61
-
62
- ### Action Space
63
 
64
  ```python
65
- @dataclass
66
- class HallucinationAction(Action):
67
- answer: str # The agent's answer
68
- confidence: float # Certainty 0.0 → 1.0
69
- source_quote: str # Direct quote from context supporting the answer
70
  ```
71
 
72
- ### Observation Space
73
-
74
  ```python
75
- @dataclass
76
- class HallucinationObservation(Observation):
77
- question: str # The question to answer
78
- context: str # Source document to answer from
79
- reward: float # Step reward
80
- feedback: str # Detailed human-readable feedback
81
- is_hallucination: bool # Was hallucination detected?
82
- hallucination_type: str # Type of hallucination detected
83
- hallucination_severity: str # NONE / MINOR / MODERATE / SEVERE / CRITICAL
84
- grounding_score: float # How well answer is grounded in context
85
- accuracy_so_far: float # Running accuracy this episode
86
- skill_rating: float # ELO-style skill rating
87
- attempts_remaining: int # Steps left in episode
88
- done: bool # Episode complete?
89
  ```
90
 
91
- ### Episode Flow
92
 
93
- ```
94
- reset()
95
- → Sample question + context from dataset (curriculum-aware)
96
- → Return initial observation
97
-
98
- step(action)
99
- → Grade answer across 6 components
100
- → Detect hallucination type and severity
101
- → Compute multi-factor reward
102
- → Adapt difficulty based on performance
103
- → Return observation with reward + rich feedback
104
-
105
- state()
106
- → Return episode metadata: ID, step count, skill rating, curriculum stage
107
- ```
108
-
109
- ---
110
 
111
- ## 🏆 Reward System
 
112
 
113
- Six components combine into a single reward signal in **[0.0, 1.0]**:
 
 
 
114
 
115
- | Component | Weight | What It Measures |
116
- |---|---|---|
117
- | **Factual Correctness** | 30% | Semantic similarity + entity overlap vs ground truth |
118
- | **Source Grounding** | 20% | Word coverage and context matching |
119
- | **Citation Accuracy** | 15% | Is source_quote actually in the document? |
120
- | **Confidence Calibration** | 15% | Does stated confidence match actual accuracy? |
121
- | **Semantic Consistency** | 10% | Logical coherence with context |
122
- | **Hallucination Penalty** | 10% | Penalty for fabricated content |
123
 
124
- **Difficulty multipliers:** beginner 0.9×, expert 1.2×
125
- **Consistency bonus:** up to +0.05 for sustained high performance
126
 
127
- ```
128
- reward = clamp(Σ(weight × score) × difficulty_multiplier + consistency_bonus, 0.0, 1.0)
129
  ```
130
 
131
- **In practice:**
132
- - Hallucinated answer with false citation → reward ≈ **0.002–0.10**, CRITICAL severity
133
- - Grounded correct answer with real quote → reward ≈ **0.85–1.00**
134
-
135
  ---
136
 
137
- ## 🔬 Hallucination Detection
138
-
139
- ### 8 Types Classified
140
 
141
- | Type | What It Catches |
142
- |---|---|
143
- | `FABRICATED_FACT` | Information stated that is not in the source |
144
- | `FALSE_CITATION` | source_quote that does not exist in the document |
145
- | `OVERCONFIDENT_WRONG` | High confidence on an incorrect answer |
146
- | `CONTEXT_DRIFT` | Answer gradually drifts away from source |
147
- | `NUMERICAL_FABRICATION` | Made-up statistics or numbers |
148
- | `ENTITY_CONFUSION` | Wrong names, organisations, or places |
149
- | `TEMPORAL_ERROR` | Incorrect dates or timelines |
150
- | `RELATIONSHIP_ERROR` | Incorrect relationships between entities |
151
 
152
- ### 5 Severity Levels
153
 
154
- | Level | Score | Meaning |
155
- |---|---|---|
156
- | NONE | 0.0 | Fully grounded answer |
157
- | MINOR | 0.1–0.3 | Slight deviation from source |
158
- | MODERATE | 0.3–0.5 | Noticeable unsupported claims |
159
- | SEVERE | 0.5–0.7 | Significantly fabricated content |
160
- | CRITICAL | 0.7+ | Answer largely invented |
161
 
162
- ### Detection Algorithms
163
 
164
- - **Word coverage** — fraction of meaningful content words in answer found in context
165
- - **Entity hallucination** — novel entities in answer not found in source
166
- - **Numerical fabrication** — numbers in answer absent from context
167
- - **Sliding window fuzzy matching** — citation verification (threshold 0.7)
168
- - **Negation mismatch** — contradiction detection via negation word analysis
169
- - **Confidence calibration error** — `|confidence − correctness|` with 50% overconfidence surcharge
170
 
171
  ---
172
 
173
- ## 📚 Datasets
174
 
175
- 2,140+ total examples loaded at runtime across four difficulty levels:
176
-
177
- | Source | Examples | Type | Difficulty |
178
- |---|---|---|---|
179
- | Synthetic (built-in) | 140 | Hallucination traps, edge cases | All levels |
180
- | **SQuAD** | ~500 | Reading comprehension | Intermediate |
181
- | **TriviaQA** | ~500 | Open-domain factual QA | Intermediate |
182
- | **HaluEval** | ~500 | Hallucination evaluation | Advanced |
183
- | **TruthfulQA** | ~500 | Factuality benchmark | Advanced/Expert |
184
-
185
- Datasets load from Hugging Face automatically on first start (`pip install datasets`).
186
- A local disk cache (`server/cache/`) is used on subsequent starts for instant loading.
187
-
188
- ### Built-in Synthetic Dataset Breakdown
189
-
190
- | Difficulty | Count | Focus |
191
- |---|---|---|
192
- | Beginner | 60 | Simple factual recall, API concepts, basic science |
193
- | Intermediate | 60 | Multi-hop reasoning, history, technology, biology |
194
- | Advanced | 10 | Hallucination traps, common misconceptions |
195
- | Expert | 10 | System mechanics, algorithms, quantum physics |
196
 
197
- ### Add Custom Datasets
198
 
199
- ```python
200
- from server.dataset_loader import DatasetLoader
201
 
202
- loader = DatasetLoader()
203
- loader.load_from_json("my_dataset.json") # Custom JSON
204
- loader.load_from_huggingface("squad") # Any HF dataset
205
  ```
206
-
207
- Custom JSON format:
208
- ```json
209
- [
210
- {
211
- "question": "What is the prize pool?",
212
- "context": "The hackathon has a total prize pool of $30,000 USD...",
213
- "answer": "$30,000 USD",
214
- "id": "q001",
215
- "source": "custom",
216
- "difficulty": "intermediate",
217
- "category": "factual_recall"
218
- }
219
- ]
220
  ```
221
 
222
- ---
223
-
224
- ## 🎓 Curriculum Learning
225
-
226
- The environment adapts difficulty in real-time using an ELO-style skill rating:
227
-
228
- | Trigger | Action |
229
- |---|---|
230
- | Recent avg reward > 0.7 | Increase difficulty |
231
- | Recent avg reward < 0.3 | Decrease difficulty |
232
- | Overall accuracy > 0.8 | EXPERT ceiling |
233
- | Overall accuracy > 0.6 | ADVANCED ceiling |
234
- | Overall accuracy > 0.4 | INTERMEDIATE ceiling |
235
-
236
- Episodes can use progressive difficulty mixing (beginner → expert within one episode) for maximum learning efficiency.
237
 
238
  ---
239
 
240
- ## 🔌 Model-Agnostic Adapters
241
 
242
- Works with any LLM out of the box:
243
 
244
  ```python
245
- from model_adapters import create_adapter
246
-
247
- # OpenAI
248
- adapter = create_adapter("openai", model_name="gpt-4", api_key="sk-...")
249
-
250
- # Anthropic Claude
251
- adapter = create_adapter("anthropic", model_name="claude-sonnet-4-6", api_key="sk-ant-...")
252
-
253
- # HuggingFace (Llama, Mistral, Qwen...)
254
- adapter = create_adapter("huggingface", model_name="meta-llama/Llama-3-8B-Instruct")
255
-
256
- # Local Ollama
257
- adapter = create_adapter("ollama", model_name="llama3", api_base="http://localhost:11434")
258
-
259
- # Use it
260
- response = adapter.generate_response(
261
- question="What is the prize pool?",
262
- context="The hackathon has $30,000 USD in prizes...",
263
- require_citation=True,
264
- require_confidence=True
265
- )
266
  ```
267
 
268
- ---
269
-
270
- ## 📊 Metrics & Monitoring
271
-
272
  ```bash
273
- curl http://localhost:8000/metrics # Live metrics
274
- curl http://localhost:8000/metrics/training-curves # Reward curves
275
- curl http://localhost:8000/metrics/heatmap # Hallucination heatmap
276
- curl http://localhost:8000/metrics/export?format=json # Export data
277
- ```
278
-
279
- Sample output after training:
280
- ```
281
- Episodes: 15 | Steps: 150
282
- Accuracy: 78.5% | Avg Reward: 0.742 | Hallucination Rate: 12.3%
283
- Reward Trend: IMPROVING ↑ | Recent Hallucination Rate: 8.2%
284
  ```
285
 
286
  ---
287
 
288
- ## 🏗️ Project Structure
289
 
290
- ```
291
- hallucination_guard_env/
292
- ├── models.py # HallucinationAction, Observation, State, Config
293
- ├── client.py # HTTP/WebSocket client
294
- ├── model_adapters.py # OpenAI, Anthropic, HuggingFace, Ollama adapters
295
- ├── test_env.py # Full test suite
296
- ├── openenv.yaml # Manifest
297
- ├── pyproject.toml # Package metadata
298
- └── server/
299
- ├── environment.py # Core RL environment logic
300
- ├── app.py # FastAPI server (stateless + session endpoints)
301
- ├── grader.py # 6-component reward + hallucination detection
302
- ├── dataset_loader.py # Multi-source dataset loader with caching
303
- ├── metrics.py # Real-time metrics tracker
304
- ├── cache/ # Pre-built dataset cache (instant startup)
305
- ├── requirements.txt
306
- └── Dockerfile
307
- ```
308
 
309
- ---
 
310
 
311
- ## ⚙️ Configuration
 
312
 
313
- ```python
314
- from models import EnvironmentConfig
315
-
316
- config = EnvironmentConfig(
317
- max_questions_per_episode=10,
318
- reward_weights={
319
- "factual_correctness": 0.30,
320
- "source_grounding": 0.20,
321
- "citation_accuracy": 0.15,
322
- "confidence_calibration": 0.15,
323
- "semantic_consistency": 0.10,
324
- "hallucination_penalty": 0.10,
325
- },
326
- adaptive_difficulty=True,
327
- difficulty_threshold_increase=0.7,
328
- difficulty_threshold_decrease=0.3,
329
- curriculum_enabled=True,
330
- )
331
-
332
- env = HallucinationEnvironment(config=config)
333
- ```
334
 
335
  ---
336
 
337
- ## 🧪 Tests
338
 
339
- ```bash
340
- python test_env.py
341
  ```
342
-
343
- Covers: dataset loading, grader components, reset/step/state, episode completion, hallucination type classification, curriculum difficulty, metrics tracking, model adapter factory.
344
-
345
- ---
346
-
347
- ## 🔗 Links
348
-
349
- | | |
350
- |---|---|
351
- | 📖 OpenEnv Docs | https://github.com/meta-pytorch/OpenEnv |
352
- | 🎓 OpenEnv Course | https://github.com/huggingface/openenv-course |
 
 
 
353
 
354
  ---
355
 
356
- ## 🏆 Why This Environment Stands Out
357
 
358
- | | |
359
- |---|---|
360
- | **Real-world origin** | Born from an actual AI hallucination experience during hackathon research |
361
- | **Solves the #1 LLM problem** | Hallucination is the most critical reliability issue in production AI |
362
- | **Novel** | First OpenEnv environment targeting hallucination and grounding |
363
- | **Rich reward signal** | 6-component system gives models precise, actionable feedback |
364
- | **2,140+ diverse examples** | SQuAD, TriviaQA, HaluEval, TruthfulQA + curated synthetic traps |
365
- | **Model-agnostic** | Works with GPT-4, Claude, Llama, Mistral, or any LLM |
366
- | **Production-ready** | Session management, metrics, caching, Dockerfile included |
367
- | **Adaptive** | ELO-based curriculum scales difficulty with the agent's skill |
368
 
369
  ---
370
 
371
- *Built for the Meta PyTorch OpenEnv Hackathon 2026 · MIT License*
 
13
  - grounded-generation
14
  - question-answering
15
  - fact-checking
16
+ - llm-evaluation
17
  - llm-training
18
+ - benchmark
19
  ---
20
 
21
+ # 🛡️ HallucinationGuard-Env v3.0
22
 
23
+ > **The production-grade OpenEnv RL environment for training and evaluating LLMs on hallucination avoidance.**
24
 
25
+ [![Running](https://img.shields.io/badge/status-running-brightgreen)](https://huggingface.co/spaces/SamSankar/hallucination-guard-env)
26
+ [![OpenEnv](https://img.shields.io/badge/OpenEnv-compatible-blue)](https://github.com/meta-pytorch/OpenEnv)
27
+ [![Datasets](https://img.shields.io/badge/datasets-50k%2B%20examples-orange)](#datasets)
28
+ [![License](https://img.shields.io/badge/license-MIT-green)](LICENSE)
29
 
30
  ---
31
 
32
+ ## Why HallucinationGuard?
33
 
34
+ Large language models hallucinate: they confidently state false information that is not supported by any evidence. This is a critical problem for companies deploying LLMs in production.
35
 
36
+ **HallucinationGuard-Env** provides a standardized, reproducible RL environment to:
37
 
38
+ - 📊 **Benchmark** any LLM's hallucination rate across 50,000+ real-world QA examples
39
+ - 🎯 **Train** models to stay grounded in provided context
40
+ - 🏆 **Compare** models on a public leaderboard
41
+ - 🔧 **Integrate** into any ML pipeline via REST API or Python SDK
42
 
43
  ---
44
 
45
+ ## Quick Start
46
 
47
+ ### Option 1 — Python SDK (recommended)
 
48
 
49
  ```bash
50
+ pip install requests
51
  ```
52
 
 
 
53
  ```python
54
+ from hallucination_guard_sdk import HallucinationGuardEnv
55
+ import anthropic
56
+
57
+ client = anthropic.Anthropic(api_key="YOUR_KEY")
58
+
59
+ def my_model(question: str, context: str) -> str:
60
+     """Your model function takes question + context, returns answer."""
61
+     msg = client.messages.create(
62
+         model="claude-3-haiku-20240307",
63
+         max_tokens=256,
64
+         messages=[{
65
+             "role": "user",
66
+             "content": f"Context: {context}\n\nQuestion: {question}\n\nAnswer using ONLY the context above."
67
+         }]
68
+     )
69
+     return msg.content[0].text
70
+
71
+ # Evaluate in 3 lines
72
+ env = HallucinationGuardEnv()
73
+ results = env.evaluate(my_model, episodes=10, model_name="claude-3-haiku")
74
+ env.submit_to_leaderboard(results, organization="Anthropic")
75
  ```
76
 
77
+ ### Option 2 — REST API
78
 
79
+ ```bash
80
+ BASE="https://samsankar-hallucination-guard-env.hf.space"
81
 
82
+ # Start episode
83
+ curl -X POST $BASE/reset
84
 
85
+ # Submit answer
86
+ curl -X POST $BASE/step \
87
+ -H "Content-Type: application/json" \
88
+ -d '{"answer": "Your answer based only on the context"}'
89
 
90
+ # View leaderboard
91
+ curl $BASE/leaderboard
92
+ ```
 
 
 
 
 
93
 
94
+ ### Option 3 — OpenAI-compatible clients
 
95
 
96
+ ```python
97
+ from openai import OpenAI
98
+ from hallucination_guard_sdk import HallucinationGuardEnv
99
+
100
+ client = OpenAI(api_key="YOUR_KEY")
101
+
102
+ def gpt4_model(question, context):
103
+ response = client.chat.completions.create(
104
+ model="gpt-4o-mini",
105
+ messages=[
106
+ {"role": "system", "content": "Answer ONLY from the provided context."},
107
+ {"role": "user", "content": f"Context: {context}\n\nQ: {question}"}
108
+ ]
109
+ )
110
+ return response.choices[0].message.content
111
+
112
+ env = HallucinationGuardEnv()
113
+ results = env.evaluate(gpt4_model, episodes=10, model_name="gpt-4o-mini")
114
+ env.submit_to_leaderboard(results, organization="OpenAI")
115
  ```
116
 
 
 
 
 
117
  ---
118
 
119
+ ## API Reference
120
+
121
+ | Method | Endpoint | Description |
122
+ |--------|----------|-------------|
123
+ | `POST` | `/reset` | Start a new episode, receive first question + context |
124
+ | `POST` | `/step` | Submit answer, receive reward + next question |
125
+ | `GET` | `/state` | Current episode state |
126
+ | `GET` | `/health` | Health check |
127
+ | `POST` | `/session/reset` | Create a stateful multi-turn session |
128
+ | `POST` | `/session/step` | Step within a named session |
129
+ | `GET` | `/leaderboard` | Public model leaderboard |
130
+ | `POST` | `/leaderboard/submit` | Submit evaluation results |
131
+ | `GET` | `/datasets` | Dataset statistics |
132
+ | `GET` | `/metrics` | Real-time usage metrics |
133
+ | `GET` | `/docs` | Interactive Swagger UI |
134
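
The endpoints above compose into a simple reset → step loop. The sketch below is a hedged illustration (not the official SDK): field names such as `question`, `context`, `done`, and `reward` follow the observation schema documented in this README, and the loop is written transport-agnostically so the HTTP layer can be swapped or stubbed.

```python
import json
import urllib.request

BASE = "https://samsankar-hallucination-guard-env.hf.space"  # live Space URL from this README

def run_episode(reset_fn, step_fn, model_fn):
    """Drive one /reset → /step loop; collects per-step rewards until done."""
    obs = reset_fn()
    rewards = []
    while not obs.get("done", False) and obs.get("question"):
        answer = model_fn(obs["question"], obs.get("context", ""))
        obs = step_fn(answer)
        rewards.append(obs.get("reward", 0.0) or 0.0)
    return rewards

def _post(path, body=None):
    """Minimal stdlib POST helper for the endpoints in the table above."""
    req = urllib.request.Request(
        BASE + path,
        data=json.dumps(body or {}).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=30) as resp:
        return json.load(resp)

# Live usage (requires network access):
# rewards = run_episode(lambda: _post("/reset"),
#                       lambda a: _post("/step", {"answer": a}),
#                       lambda q, c: c[:100])
```

Because `run_episode` takes plain callables, the same loop works against a local `uvicorn` instance or a fully mocked environment.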
 
135
+ ---
136
 
137
+ ## Reward System
138
 
139
+ Each answer is scored across 6 dimensions:
 
 
 
 
 
 
140
 
141
+ | Component | Weight | Description |
142
+ |-----------|--------|-------------|
143
+ | Factual correctness | 35% | Does the answer match the ground truth? |
144
+ | Source grounding | 30% | Is the answer supported by the context? |
145
+ | Citation accuracy | 15% | Does the answer cite specific context passages? |
146
+ | Confidence calibration | 10% | Is confidence appropriate to accuracy? |
147
+ | Semantic consistency | 5% | Is the answer semantically coherent? |
148
+ | Hallucination penalty | 5% | Was any fabricated content detected? |
149
 
150
+ **Reward range:** -1.0 (complete hallucination) to +1.0 (perfect grounded answer)
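
As a back-of-the-envelope sketch of how the table's weights could combine, the snippet below rescales a weighted sum of per-component scores (each assumed to lie in [0, 1]) onto the documented [-1.0, +1.0] range. This is illustrative only; the authoritative scoring logic lives in `server/grader.py`.

```python
# Sketch only: combines the six component scores from the table above.
# The linear rescale to [-1, +1] is an assumption for illustration,
# not the server's exact formula.
WEIGHTS = {
    "factual_correctness": 0.35,
    "source_grounding": 0.30,
    "citation_accuracy": 0.15,
    "confidence_calibration": 0.10,
    "semantic_consistency": 0.05,
    "hallucination_penalty": 0.05,
}

def combine_reward(scores: dict) -> float:
    """Weighted sum of component scores, rescaled from [0, 1] to [-1, +1]."""
    weighted = sum(w * scores.get(name, 0.0) for name, w in WEIGHTS.items())
    return 2.0 * weighted - 1.0
```

All-perfect component scores map to +1.0 and all-zero scores to -1.0, matching the stated reward range.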
151
 
152
  ---
153
 
154
+ ## Datasets
155
+
156
+ 50,000+ examples across 13 real-world QA datasets:
157
+
158
+ | Dataset | Size | Category | Difficulty |
159
+ |---------|------|----------|------------|
160
+ | SQuAD | 5,000 | Reading comprehension | Intermediate |
161
+ | TriviaQA | 5,000 | Trivia / general knowledge | Intermediate |
162
+ | HaluEval | 2,000 | Hallucination detection | Advanced |
163
+ | TruthfulQA | 817 | Factuality benchmark | Expert |
164
+ | Natural Questions | 5,000 | Open-domain QA | Intermediate |
165
+ | HotpotQA | 5,000 | Multi-hop reasoning | Advanced |
166
+ | BoolQ | 5,000 | Yes/No questions | Beginner |
167
+ | FaithDial | 5,000 | Hallucination in dialogue | Advanced |
168
+ | FEVER | 5,000 | Fact verification | Advanced |
169
+ | ARC-Challenge | 2,000 | Science exam | Advanced |
170
+ | OpenBookQA | 2,000 | Science facts | Intermediate |
171
+ | MS MARCO | 5,000 | Web search QA | Intermediate |
172
+ | CoQA | 5,000 | Conversational QA | Intermediate |
173
 
174
+ ---
175
 
176
+ ## Curriculum Learning
177
 
178
+ The environment implements adaptive difficulty:
 
179
 
 
 
 
180
  ```
181
+ Beginner  →  Intermediate  →  Advanced   →  Expert
182
+ BoolQ        SQuAD            HotpotQA      TruthfulQA
183
+ (yes/no)     (reading)        (multi-hop)   (factuality)
184
  ```
185
 
186
+ Difficulty adjusts automatically based on the agent's rolling skill rating.
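
A minimal sketch of what threshold-based difficulty adaptation can look like. The 0.7 / 0.3 promote/demote thresholds mirror the environment's configurable defaults (`difficulty_threshold_increase` / `difficulty_threshold_decrease`); the actual server logic lives in `server/environment.py`.

```python
# Illustrative sketch of curriculum adaptation, assuming the default
# 0.7 / 0.3 thresholds; not the server's exact implementation.
LEVELS = ["beginner", "intermediate", "advanced", "expert"]

def adjust_difficulty(current: str, recent_rewards: list) -> str:
    """Move up a level when the agent does well, down when it struggles."""
    if not recent_rewards:
        return current
    avg = sum(recent_rewards) / len(recent_rewards)
    idx = LEVELS.index(current)
    if avg > 0.7 and idx < len(LEVELS) - 1:
        return LEVELS[idx + 1]
    if avg < 0.3 and idx > 0:
        return LEVELS[idx - 1]
    return current
```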
187
 
188
  ---
189
 
190
+ ## Leaderboard
191
 
192
+ Submit your model's results to the public leaderboard:
193
 
194
  ```python
195
+ env = HallucinationGuardEnv()
196
+ results = env.evaluate(my_model, episodes=10)
197
+ env.submit_to_leaderboard(results, organization="YourCompany")
198
  ```
199
 
200
+ Or via API:
 
 
 
201
  ```bash
202
+ curl -X POST https://samsankar-hallucination-guard-env.hf.space/leaderboard/submit \
203
+ -H "Content-Type: application/json" \
204
+ -d '{
205
+ "model_name": "gpt-4o",
206
+ "avg_reward": 0.72,
207
+ "avg_accuracy": 0.81,
208
+ "hallucination_rate": 0.19,
209
+ "total_episodes": 10,
210
+ "total_steps": 100,
211
+ "organization": "OpenAI"
212
+ }'
213
  ```
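
The same submission can be made from Python with only the standard library. This is a hedged equivalent of the curl call above; the metric values are placeholders, not real evaluation results.

```python
import json
import urllib.request

# Payload fields mirror the documented /leaderboard/submit body.
payload = {
    "model_name": "gpt-4o",
    "avg_reward": 0.72,
    "avg_accuracy": 0.81,
    "hallucination_rate": 0.19,
    "total_episodes": 10,
    "total_steps": 100,
    "organization": "OpenAI",
}

def submit(base_url: str, body: dict) -> bytes:
    """POST the results payload to /leaderboard/submit and return the raw response."""
    req = urllib.request.Request(
        base_url.rstrip("/") + "/leaderboard/submit",
        data=json.dumps(body).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=30) as resp:
        return resp.read()

# Live usage (requires network access):
# submit("https://samsankar-hallucination-guard-env.hf.space", payload)
```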
214
 
215
  ---
216
 
217
+ ## Use Cases
218
 
219
+ ### For AI Companies
220
+ Benchmark your models before deployment. Compare across model versions. Track hallucination regression.
221
 
222
+ ### For Researchers
223
+ Standardized evaluation protocol. 50k+ diverse examples. Reproducible results via seed parameter.
224
 
225
+ ### For Developers
226
+ REST API — works with any language. Python SDK — 3 lines to evaluate. Per-dataset caching for fast iteration.
227
 
228
+ ### For RL Training
229
+ Full OpenEnv-compatible interface. Curriculum learning built-in. Reward signal optimized for RL training loops.
230
 
231
  ---
232
 
233
+ ## Architecture
234
 
 
 
235
  ```
236
+ ┌─────────────────────────────────────────────────┐
237
+ │                  FastAPI Server                 │
238
+ │  /reset → /step → reward signal → /leaderboard  │
239
+ ├─────────────────────────────────────────────────┤
240
+ │            HallucinationEnvironment             │
241
+ │    Episode management · Curriculum learning     │
242
+ ├─────────────────────────────────────────────────┤
243
+ │                     Grader                      │
244
+ │  Semantic similarity · NLI · Citation detection │
245
+ ├─────────────────────────────────────────────────┤
246
+ │                 Dataset Loader                  │
247
+ │   13 datasets · 50k+ examples · Per-file cache  │
248
+ └─────────────────────────────────────────────────┘
249
+ ```
250
 
251
  ---
252
 
253
+ ## License
254
 
255
+ MIT License — free for research and commercial use.
256
 
257
  ---
258
 
259
+ *Built for the Meta PyTorch OpenEnv Hackathon 2026*
evaluate_groq.py ADDED
@@ -0,0 +1,64 @@
+ """
+ HallucinationGuard-Env — Groq/Llama Evaluator (SDK version)
+ Uses the HallucinationGuard SDK + Groq free tier
+
+ Setup:
+     pip install groq requests
+     Get free key at https://console.groq.com
+     python evaluate_groq.py --api-key YOUR_GROQ_KEY --episodes 5
+ """
+
+ import argparse
+ import sys
+
+ try:
+     from groq import Groq
+ except ImportError:
+     print("Run: pip install groq requests")
+     sys.exit(1)
+
+ from hallucination_guard_sdk import HallucinationGuardEnv
+
+ MODEL = "llama-3.1-8b-instant"
+
+ SYSTEM = """Answer questions using ONLY the provided context.
+ If the context lacks real information, say: "The context does not contain enough information."
+ Never use outside knowledge. Be concise."""
+
+
+ def make_model_fn(client):
+     def model_fn(question: str, context: str) -> str:
+         r = client.chat.completions.create(
+             model=MODEL,
+             messages=[
+                 {"role": "system", "content": SYSTEM},
+                 {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {question}"},
+             ],
+             max_tokens=200,
+             temperature=0.1,
+         )
+         return r.choices[0].message.content.strip()
+     return model_fn
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--api-key", required=True)
+     parser.add_argument("--episodes", type=int, default=5)
+     parser.add_argument("--model-name", default="llama-3.1-8b-groq")
+     parser.add_argument("--organization", default="")
+     parser.add_argument("--submit", action="store_true",
+                         help="Submit results to leaderboard")
+     args = parser.parse_args()
+
+     client = Groq(api_key=args.api_key)
+     model_fn = make_model_fn(client)
+
+     env = HallucinationGuardEnv()
+     results = env.evaluate(model_fn, episodes=args.episodes,
+                            model_name=args.model_name)
+     env.save_results(results)
+
+     if args.submit:
+         env.submit_to_leaderboard(results, organization=args.organization)
+
+
+ if __name__ == "__main__":
+     main()
hallucination_guard_sdk.py ADDED
@@ -0,0 +1,345 @@
+ """
+ HallucinationGuard SDK v3.0
+ ===========================
+ The easiest way to evaluate any LLM for hallucination using HallucinationGuard-Env.
+
+ Install:
+     pip install requests
+
+ Usage (3 lines):
+     from hallucination_guard_sdk import HallucinationGuardEnv
+     env = HallucinationGuardEnv()
+     results = env.evaluate(your_model_fn, episodes=5)
+
+ Full example:
+     import anthropic
+     client = anthropic.Anthropic(api_key="...")
+
+     def my_model(question, context):
+         msg = client.messages.create(
+             model="claude-3-haiku-20240307",
+             max_tokens=256,
+             messages=[{"role": "user", "content": f"Context: {context}\\n\\nQuestion: {question}\\n\\nAnswer using ONLY the context."}]
+         )
+         return msg.content[0].text
+
+     env = HallucinationGuardEnv()
+     results = env.evaluate(my_model, episodes=5, model_name="claude-3-haiku")
+     env.print_report(results)
+     env.submit_to_leaderboard(results)
+ """
+
+ import time
+ import json
+ import sys
+ from typing import Callable, Optional, Dict, Any, List
+
+ try:
+     import requests
+ except ImportError:
+     print("Run: pip install requests")
+     sys.exit(1)
+
+
+ class HallucinationGuardEnv:
+     """
+     Python SDK for HallucinationGuard-Env.
+
+     Parameters
+     ----------
+     base_url : str
+         URL of the deployed environment. Defaults to the live HF Space.
+     verbose : bool
+         Print step-by-step output during evaluation.
+     """
+
+     BASE_URL = "https://samsankar-hallucination-guard-env.hf.space"
+
+     def __init__(
+         self,
+         base_url: str = BASE_URL,
+         verbose: bool = True,
+     ):
+         self.base_url = base_url.rstrip("/")
+         self.verbose = verbose
+         self._check_health()
+
+     # ── Core methods ──────────────────────────────────────────────────────────
+
+     def reset(self, difficulty: Optional[str] = None, seed: Optional[int] = None) -> Dict:
+         """Reset the environment. Returns the first observation."""
+         body = {}
+         if difficulty:
+             body["difficulty"] = difficulty
+         if seed is not None:
+             body["seed"] = seed
+         return self._post("/reset", body)
+
+     def step(self, answer: str) -> Dict:
+         """Submit an answer. Returns reward, hallucination flag, feedback, next question."""
+         return self._post("/step", {"answer": answer})
+
+     def health(self) -> Dict:
+         """Check if the environment is running."""
+         return self._get("/health")
+
+     def leaderboard(self) -> Dict:
+         """Get the current leaderboard."""
+         return self._get("/leaderboard")
+
+     def dataset_info(self) -> Dict:
+         """Get statistics about loaded datasets."""
+         return self._get("/datasets")
+
+     # ── High-level evaluate() ─────────────────────────────────────────────────
+
+     def evaluate(
+         self,
+         model_fn: Callable[[str, str], str],
+         episodes: int = 3,
+         difficulty: Optional[str] = None,
+         model_name: str = "my_model",
+         delay: float = 0.5,
+     ) -> Dict[str, Any]:
+         """
+         Run a full evaluation of your model against the environment.
+
+         Parameters
+         ----------
+         model_fn : callable
+             Function that takes (question: str, context: str) → answer: str
+         episodes : int
+             Number of episodes to run (default: 3)
+         difficulty : str, optional
+             Force a difficulty level: beginner | intermediate | advanced | expert
+         model_name : str
+             Name for the leaderboard
+         delay : float
+             Seconds to wait between API calls (be gentle with free tier)
+
+         Returns
+         -------
+         dict with summary stats and full episode logs
+
+         Example
+         -------
+         >>> def my_model(question, context):
+         ...     # call your LLM here
+         ...     return "answer from context"
+         >>> env = HallucinationGuardEnv()
+         >>> results = env.evaluate(my_model, episodes=5)
+         """
+         if self.verbose:
+             print(f"\n🛡️ HallucinationGuard-Env — Evaluating: {model_name}")
+             print(f"   Episodes  : {episodes}")
+             print(f"   Difficulty: {difficulty or 'mixed'}")
+             print(f"   Endpoint  : {self.base_url}\n")
+
+         all_episodes = []
+
+         for ep_num in range(1, episodes + 1):
+             if self.verbose:
+                 print(f"{'='*60}")
+                 print(f"  EPISODE {ep_num}/{episodes}")
+                 print(f"{'='*60}")
+
+             ep_result = self._run_episode(model_fn, ep_num, difficulty, delay)
+             all_episodes.append(ep_result)
+
+             if self.verbose:
+                 print(f"  ─ Episode {ep_num} complete │ "
+                       f"accuracy: {ep_result['accuracy']*100:.0f}% │ "
+                       f"reward: {ep_result['avg_reward']:.3f} │ "
+                       f"hallucinations: {ep_result['hallucinations']}/{ep_result['steps']}")
+
+             time.sleep(delay)
+
+         # ── Aggregate ─────────────────────────────────────────────────────────
+         total_steps = sum(e["steps"] for e in all_episodes)
+         total_halluc = sum(e["hallucinations"] for e in all_episodes)
+         avg_accuracy = sum(e["accuracy"] for e in all_episodes) / len(all_episodes)
+         avg_reward = sum(e["avg_reward"] for e in all_episodes) / len(all_episodes)
+         avg_skill = sum(e["final_skill"] for e in all_episodes) / len(all_episodes)
+         best_streak = max(e["best_streak"] for e in all_episodes)
+         halluc_rate = total_halluc / max(total_steps, 1)
+
+         results = {
+             "model_name": model_name,
+             "episodes": episodes,
+             "total_steps": total_steps,
+             "avg_accuracy": round(avg_accuracy, 4),
+             "avg_reward": round(avg_reward, 4),
+             "hallucination_rate": round(halluc_rate, 4),
+             "best_streak": best_streak,
+             "avg_skill_rating": round(avg_skill, 4),
+             "episode_logs": all_episodes,
+         }
+
+         if self.verbose:
+             self.print_report(results)
+
+         return results
+
+     def _run_episode(self, model_fn, ep_num, difficulty, delay) -> Dict:
+         obs = self.reset(difficulty=difficulty)
+         step_logs = []
+         step = 0
+
+         while not obs.get("done", False):
+             question = obs.get("question", "")
+             context = obs.get("context", "")
+             step += 1
+
+             if not question:
+                 break
+
+             if self.verbose:
+                 q_display = question[:75] + "..." if len(question) > 75 else question
+                 print(f"\n  Step {step} [{obs.get('source_dataset', '?')}]")
+                 print(f"  Q: {q_display}")
+
+             # Call the model
+             try:
+                 answer = model_fn(question, context)
+             except Exception as e:
+                 answer = f"Error calling model: {e}"
+
+             if self.verbose:
+                 a_display = answer[:90] + "..." if len(answer) > 90 else answer
+                 print(f"  A: {a_display}")
+
+             obs = self.step(answer)
+
+             reward = obs.get("reward", 0) or 0
+             is_halluc = obs.get("is_hallucination", False)
+             status = "❌ HALLUCINATION" if is_halluc else "✅ OK"
+
+             if self.verbose:
+                 print(f"  {status} │ reward: {reward:.3f} │ skill: {obs.get('skill_rating', 0):.3f}")
+
+             step_logs.append({
+                 "step": step,
+                 "question": question,
+                 "answer": answer,
+                 "reward": reward,
+                 "is_hallucination": is_halluc,
+                 "hallucination_type": obs.get("hallucination_type"),
+                 "source": obs.get("source_dataset", ""),
+             })
+
+             time.sleep(delay)
+
+         accuracy = obs.get("accuracy_so_far", 0)
+         best_streak = obs.get("best_streak", 0)
+         final_skill = obs.get("skill_rating", 0)
+         avg_reward = sum(s["reward"] for s in step_logs) / max(len(step_logs), 1)
+         hallucinations = sum(1 for s in step_logs if s["is_hallucination"])
+
+         return {
+             "episode": ep_num,
+             "steps": len(step_logs),
+             "accuracy": accuracy,
+             "avg_reward": avg_reward,
+             "best_streak": best_streak,
+             "hallucinations": hallucinations,
+             "final_skill": final_skill,
+             "step_logs": step_logs,
+         }
+
+     # ── Reporting ──────────────────────────────────────────────────────────────
+
+     def print_report(self, results: Dict) -> None:
+         """Print a formatted evaluation report."""
+         print(f"\n{'='*60}")
+         print(f"  📊 EVALUATION REPORT — {results['model_name']}")
+         print(f"{'='*60}")
+         print(f"  Episodes run        : {results['episodes']}")
+         print(f"  Total steps         : {results['total_steps']}")
+         print(f"  Avg accuracy        : {results['avg_accuracy']*100:.1f}%")
+         print(f"  Avg reward          : {results['avg_reward']:.4f}")
+         print(f"  Hallucination rate  : {results['hallucination_rate']*100:.1f}%")
259
+ print(f" Best answer streak : {results['best_streak']}")
260
+ print(f" Avg skill rating : {results['avg_skill_rating']:.4f}")
261
+ print(f"{'='*60}\n")
262
+
263
+ def save_results(self, results: Dict, filepath: str = "evaluation_results.json") -> None:
264
+ """Save evaluation results to a JSON file."""
265
+ with open(filepath, "w") as f:
266
+ json.dump(results, f, indent=2)
267
+ print(f"Results saved to: {filepath}")
268
+
269
+ def submit_to_leaderboard(
270
+ self,
271
+ results: Dict,
272
+ organization: str = "",
273
+ notes: str = "",
274
+ ) -> Dict:
275
+ """
276
+ Submit your evaluation results to the public leaderboard.
277
+
278
+ Parameters
279
+ ----------
280
+ results : dict
281
+ Output from evaluate()
282
+ organization : str
283
+ Your company/institution name
284
+ notes : str
285
+ Any notes about the evaluation setup
286
+ """
287
+ payload = {
288
+ "model_name": results["model_name"],
289
+ "avg_reward": results["avg_reward"],
290
+ "avg_accuracy": results["avg_accuracy"],
291
+ "hallucination_rate": results["hallucination_rate"],
292
+ "total_episodes": results["episodes"],
293
+ "total_steps": results["total_steps"],
294
+ "organization": organization,
295
+ "notes": notes,
296
+ }
297
+ response = self._post("/leaderboard/submit", payload)
298
+ if self.verbose:
299
+ print(f"🏆 Submitted to leaderboard: {results['model_name']}")
300
+ print(f" View at: {self.base_url}/leaderboard")
301
+ return response
302
+
303
+ # ── HTTP helpers ───────────────────────────────────────────────────────────
304
+
305
+ def _get(self, path: str) -> Dict:
306
+ try:
307
+ r = requests.get(f"{self.base_url}{path}", timeout=30)
308
+ r.raise_for_status()
309
+ return r.json()
310
+ except Exception as e:
311
+ raise ConnectionError(f"GET {path} failed: {e}")
312
+
313
+ def _post(self, path: str, body: Optional[Dict] = None) -> Dict:
314
+ try:
315
+ r = requests.post(f"{self.base_url}{path}", json=body or {}, timeout=30)
316
+ r.raise_for_status()
317
+ return r.json()
318
+ except Exception as e:
319
+ raise ConnectionError(f"POST {path} failed: {e}")
320
+
321
+ def _check_health(self) -> None:
322
+ try:
323
+ h = self._get("/health")
324
+ if self.verbose:
325
+ print(f"✅ Connected to HallucinationGuard-Env ({h.get('version','?')})")
326
+ except Exception as e:
327
+ print(f"⚠️ Could not reach {self.base_url}: {e}")
328
+
329
+
330
+ # ── CLI quick-test ─────────────────────────────────────────────────────────────
331
+
332
+ if __name__ == "__main__":
333
+ """Quick smoke-test using a simple rule-based 'model'."""
334
+
335
+ def dummy_model(question: str, context: str) -> str:
336
+ """Answers only from context — extracts a key phrase."""
337
+ words = context.split()
338
+ if len(words) > 5:
339
+ return " ".join(words[:10])
340
+ return context
341
+
342
+ env = HallucinationGuardEnv()
343
+ results = env.evaluate(dummy_model, episodes=2, model_name="dummy-baseline")
344
+ env.save_results(results, "dummy_results.json")
345
+ env.submit_to_leaderboard(results, organization="Test Org", notes="Baseline run")
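`evaluate()` only needs a callable with the signature `model_fn(question: str, context: str) -> str`. Before wiring up a real LLM, the harness can be smoke-tested with an extractive baseline that can only answer from the supplied context; `context_only_model` below is a hypothetical helper, not part of the SDK:

```python
import re

def context_only_model(question: str, context: str) -> str:
    """Return the context sentence that shares the most words with the question.

    By construction this 'model' cannot fabricate: every answer is a verbatim
    sentence of the provided context.
    """
    sentences = re.split(r"(?<=[.!?])\s+", context.strip())
    q_words = set(re.findall(r"\w+", question.lower()))
    return max(sentences,
               key=lambda s: len(q_words & set(re.findall(r"\w+", s.lower()))))
```

Passed as `env.evaluate(context_only_model, episodes=2, model_name="extractive-baseline")`, it gives a grounded floor to compare real models against.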
server/app.py CHANGED
@@ -1,35 +1,89 @@
1
- """FastAPI server for HallucinationGuard-Env with session management.
2
-
3
- Standard endpoints (/reset, /step, /state, /health) — stateless, new env per request.
4
- Session endpoints (/session/reset, /session/step) — stateful, env persists across calls.

5
  """
6
 
7
- import sys, os, uuid, logging, dataclasses, enum
8
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
9
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10
 
11
  from fastapi import FastAPI, HTTPException, Header
12
  from fastapi.responses import JSONResponse, RedirectResponse
13
- from typing import Dict, Any, Optional
 
14
 
15
  from models import HallucinationAction, HallucinationObservation, HallucinationState
16
  from environment import HallucinationEnvironment
17
  from metrics import get_tracker
18
 
19
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
20
  logger = logging.getLogger(__name__)
21
 
22
  app = FastAPI(
23
  title="HallucinationGuard-Env",
24
- description="OpenEnv RL environment for training AI to avoid hallucinations",
25
- version="2.0.0",

26
  )
27
 
28
- # Session storage for stateful HTTP interactions
29
  _sessions: Dict[str, HallucinationEnvironment] = {}
30
- # Shared stateless env instance for standard endpoints
31
  _default_env: Optional[HallucinationEnvironment] = None
32
 
 
 
 
33
 
34
  def _get_default_env() -> HallucinationEnvironment:
35
  global _default_env
@@ -39,12 +93,8 @@ def _get_default_env() -> HallucinationEnvironment:
39
 
40
 
41
  def _safe_dict(obj):
42
- """Recursively convert dataclass/enum/dict to JSON-safe structure."""
43
  if dataclasses.is_dataclass(obj):
44
- result = {}
45
- for f in dataclasses.fields(obj):
46
- result[f.name] = _safe_dict(getattr(obj, f.name))
47
- return result
48
  elif isinstance(obj, enum.Enum):
49
  return obj.value
50
  elif isinstance(obj, dict):
@@ -56,11 +106,27 @@ def _safe_dict(obj):
56
  return str(obj)
57
 
58

  # ── Standard stateless endpoints ──────────────────────────────────────────────
60
 
61
- @app.post("/reset")
62
  async def reset(body: Dict[str, Any] = {}):
63
- """Reset environment and return initial observation."""

64
  try:
65
  env = _get_default_env()
66
  obs = env.reset(**{k: v for k, v in body.items()
@@ -71,9 +137,19 @@ async def reset(body: Dict[str, Any] = {}):
71
  raise HTTPException(status_code=500, detail=str(e))
72
 
73
 
74
- @app.post("/step")
75
  async def step(action_data: Dict[str, Any]):
76
- """Take a step with the provided action."""

77
  try:
78
  env = _get_default_env()
79
  valid = {f.name for f in dataclasses.fields(HallucinationAction)}
@@ -84,23 +160,25 @@ async def step(action_data: Dict[str, Any]):
84
  raise HTTPException(status_code=500, detail=str(e))
85
 
86
 
87
- @app.get("/state")
88
  async def get_state():
89
- """Get current environment state."""
90
  try:
91
  return JSONResponse(content=_safe_dict(_get_default_env().state()))
92
  except Exception as e:
93
  raise HTTPException(status_code=500, detail=str(e))
94
 
95
 
96
- # ── Session-based stateful endpoints ──────────────────────────────────────────
97
 
98
- @app.post("/session/reset")
99
- async def session_reset(
100
- body: Dict[str, Any] = {},
101
- x_session_id: Optional[str] = Header(None),
102
- ) -> Dict[str, Any]:
103
- """Create or reset a named session."""
 
 
104
  session_id = x_session_id or str(uuid.uuid4())
105
  if session_id in _sessions:
106
  _sessions[session_id].close()
@@ -110,16 +188,13 @@ async def session_reset(
110
  "enable_multi_turn", "enable_context_retrieval")})
111
  result = _safe_dict(obs)
112
  result["session_id"] = session_id
113
- logger.info(f"Created session {session_id}")
114
  return result
115
 
116
 
117
- @app.post("/session/step")
118
- async def session_step(
119
- action_data: Dict[str, Any],
120
- x_session_id: str = Header(...),
121
- ) -> Dict[str, Any]:
122
- """Execute a step in an existing session."""
123
  if x_session_id not in _sessions:
124
  raise HTTPException(status_code=404,
125
  detail=f"Session {x_session_id} not found. Call /session/reset first.")
@@ -131,8 +206,8 @@ async def session_step(
131
  return result
132
 
133
 
134
- @app.delete("/session")
135
- async def close_session(x_session_id: str = Header(...)) -> Dict[str, str]:
136
  """Close and clean up a session."""
137
  if x_session_id in _sessions:
138
  _sessions[x_session_id].close()
@@ -140,19 +215,140 @@ async def close_session(x_session_id: str = Header(...)) -> Dict[str, str]:
140
  return {"status": "closed", "session_id": x_session_id}
141
 
142
 
143
- @app.get("/session/list")
144
- async def list_sessions() -> Dict[str, Any]:
145
  return {"active_sessions": len(_sessions), "session_ids": list(_sessions.keys())}
146
 
147
 
148
- # ── Utility endpoints ──────────────────────────────────────────────────────────

149
 
150
- @app.get("/health")
 
 
151
  async def health():
152
- return {"status": "healthy", "service": "HallucinationGuard-Env", "version": "2.0.0"}

153
 
154
 
155
- @app.get("/metrics")
156
  async def get_metrics():
157
  try:
158
  return get_tracker().get_real_time_metrics()
@@ -160,7 +356,7 @@ async def get_metrics():
160
  raise HTTPException(status_code=500, detail=str(e))
161
 
162
 
163
- @app.get("/metrics/summary")
164
  async def metrics_summary():
165
  try:
166
  return {"summary": get_tracker().generate_summary_report()}
@@ -168,24 +364,7 @@ async def metrics_summary():
168
  raise HTTPException(status_code=500, detail=str(e))
169
 
170
 
171
- @app.get("/environment/info")
172
- async def env_info():
173
- return {
174
- "name": "HallucinationGuard-Env",
175
- "version": "2.0.0",
176
- "endpoints": {
177
- "standard": ["/reset", "/step", "/state", "/health"],
178
- "session": ["/session/reset", "/session/step", "/session", "/session/list"],
179
- "metrics": ["/metrics", "/metrics/summary"],
180
- },
181
- "difficulty_levels": ["beginner", "intermediate", "advanced", "expert"],
182
- "hallucination_types": [
183
- "fabricated_fact", "false_citation", "overconfident_wrong",
184
- "context_drift", "numerical_fabrication", "entity_confusion",
185
- ],
186
- "supported_models": ["openai", "anthropic", "huggingface", "ollama", "generic"],
187
- }
188
-
189
 
190
  @app.middleware("http")
191
  async def log_requests(request, call_next):
@@ -194,12 +373,6 @@ async def log_requests(request, call_next):
194
  return response
195
 
196
 
197
-
198
- @app.get("/")
199
- async def root():
200
- return RedirectResponse(url="/docs")
201
-
202
-
203
  if __name__ == "__main__":
204
  import uvicorn
205
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ """
2
+ HallucinationGuard-Env v3.0 — Production FastAPI Server
3
+
4
+ Endpoints:
5
+ Standard : POST /reset POST /step GET /state GET /health
6
+ Session : POST /session/reset POST /session/step DELETE /session
7
+ Leaderboard: GET /leaderboard POST /leaderboard/submit DELETE /leaderboard/{model}
8
+ Info : GET / GET /docs GET /environment/info GET /datasets
9
+ GET /metrics GET /metrics/summary
10
  """
11
 
12
+ import sys, os, uuid, logging, dataclasses, enum, time
13
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
14
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
15
 
16
  from fastapi import FastAPI, HTTPException, Header
17
  from fastapi.responses import JSONResponse, RedirectResponse
18
+ from fastapi.middleware.cors import CORSMiddleware
19
+ from typing import Dict, Any, Optional, List
20
 
21
  from models import HallucinationAction, HallucinationObservation, HallucinationState
22
  from environment import HallucinationEnvironment
23
  from metrics import get_tracker
24
 
25
+ logging.basicConfig(level=logging.INFO,
26
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
27
  logger = logging.getLogger(__name__)
28
 
29
  app = FastAPI(
30
  title="HallucinationGuard-Env",
31
+ description="""
32
+ ## 🛡️ HallucinationGuard-Env v3.0
33
+
34
+ **The production-grade OpenEnv RL environment for training and evaluating LLMs on hallucination avoidance.**
35
+
36
+ Built on 50,000+ examples across 13 real-world QA datasets:
37
+ SQuAD · TriviaQA · HaluEval · TruthfulQA · Natural Questions · HotpotQA ·
38
+ BoolQ · FaithDial · FEVER · ARC · OpenBookQA · MS MARCO · CoQA
39
+
40
+ ### Quick Start
41
+
42
+ ```python
43
+ # pip install requests
44
+ import requests
45
+
46
+ BASE = "https://samsankar-hallucination-guard-env.hf.space"
47
+
48
+ # 1. Start episode
49
+ obs = requests.post(f"{BASE}/reset").json()
50
+ print(obs["question"], obs["context"])
51
+
52
+ # 2. Answer from context only
53
+ result = requests.post(f"{BASE}/step", json={"answer": "your answer"}).json()
54
+ print(result["reward"], result["is_hallucination"])
55
+ ```
56
+
57
+ ### Python SDK
58
+
59
+ ```python
60
+ # pip install hallucination-guard-sdk  (coming soon)
61
+ from hallucination_guard import HallucinationGuardEnv
62
+ env = HallucinationGuardEnv()
63
+ obs = env.reset()
64
+ result = env.step(your_model(obs["question"], obs["context"]))
65
+ ```
66
+ """,
67
+ version="3.0.0",
68
+ contact={"name": "HallucinationGuard", "url": "https://huggingface.co/spaces/SamSankar/hallucination-guard-env"},
69
+ license_info={"name": "MIT"},
70
+ )
71
+
72
+ # CORS — allow all origins so any company/researcher can call this API
73
+ app.add_middleware(
74
+ CORSMiddleware,
75
+ allow_origins=["*"],
76
+ allow_methods=["*"],
77
+ allow_headers=["*"],
78
  )
79
 
80
+ # ── State ──────────────────────────────────────────────────────────────────────
81
  _sessions: Dict[str, HallucinationEnvironment] = {}
 
82
  _default_env: Optional[HallucinationEnvironment] = None
83
 
84
+ # Leaderboard: { model_name: {score, hallucination_rate, episodes, submitted_at} }
85
+ _leaderboard: Dict[str, Dict[str, Any]] = {}
86
+
87
 
88
  def _get_default_env() -> HallucinationEnvironment:
89
  global _default_env
 
93
 
94
 
95
  def _safe_dict(obj):
 
96
  if dataclasses.is_dataclass(obj):
97
+ return {f.name: _safe_dict(getattr(obj, f.name)) for f in dataclasses.fields(obj)}
 
 
 
98
  elif isinstance(obj, enum.Enum):
99
  return obj.value
100
  elif isinstance(obj, dict):
 
106
  return str(obj)
107
 
108
 
109
+ # ── Root ───────────────────────────────────────────────────────────────────────
110
+
111
+ @app.get("/", include_in_schema=False)
112
+ async def root():
113
+ return RedirectResponse(url="/docs")
114
+
115
+
116
  # ── Standard stateless endpoints ──────────────────────────────────────────────
117
 
118
+ @app.post("/reset", summary="Start a new episode", tags=["Environment"])
119
  async def reset(body: Dict[str, Any] = {}):
120
+ """
121
+ Reset the environment and receive the first question + context.
122
+
123
+ **Returns:** question, context, difficulty, attempts_remaining, skill_rating
124
+
125
+ **Optional body params:**
126
+ - `seed` (int): reproducible episode
127
+ - `difficulty` (str): beginner | intermediate | advanced | expert
128
+ - `episode_id` (str): custom episode ID
129
+ """
130
  try:
131
  env = _get_default_env()
132
  obs = env.reset(**{k: v for k, v in body.items()
 
137
  raise HTTPException(status_code=500, detail=str(e))
138
 
139
 
140
+ @app.post("/step", summary="Submit an answer", tags=["Environment"])
141
  async def step(action_data: Dict[str, Any]):
142
+ """
143
+ Submit an answer to the current question.
144
+
145
+ **Body:**
146
+ ```json
147
+ {"answer": "Your answer based ONLY on the provided context"}
148
+ ```
149
+
150
+ **Returns:** reward (-1 to 1), is_hallucination, hallucination_type,
151
+ grounding_score, feedback, next question + context
152
+ """
153
  try:
154
  env = _get_default_env()
155
  valid = {f.name for f in dataclasses.fields(HallucinationAction)}
 
160
  raise HTTPException(status_code=500, detail=str(e))
161
 
162
 
163
+ @app.get("/state", summary="Get current episode state", tags=["Environment"])
164
  async def get_state():
165
+ """Returns full episode state: step count, accuracy, skill rating, streaks."""
166
  try:
167
  return JSONResponse(content=_safe_dict(_get_default_env().state()))
168
  except Exception as e:
169
  raise HTTPException(status_code=500, detail=str(e))
170
 
171
 
172
+ # ── Session endpoints ──────────────────────────────────────────────────────────
173
 
174
+ @app.post("/session/reset", summary="Create a stateful session", tags=["Sessions"])
175
+ async def session_reset(body: Dict[str, Any] = {},
176
+ x_session_id: Optional[str] = Header(None)):
177
+ """
178
+ Create a persistent session for multi-turn evaluation.
179
+ Pass `X-Session-Id` header to reuse an existing session.
180
+ Returns a `session_id` to use in subsequent calls.
181
+ """
182
  session_id = x_session_id or str(uuid.uuid4())
183
  if session_id in _sessions:
184
  _sessions[session_id].close()
 
188
  "enable_multi_turn", "enable_context_retrieval")})
189
  result = _safe_dict(obs)
190
  result["session_id"] = session_id
 
191
  return result
192
 
193
 
194
+ @app.post("/session/step", summary="Step in a session", tags=["Sessions"])
195
+ async def session_step(action_data: Dict[str, Any],
196
+ x_session_id: str = Header(...)):
197
+ """Submit an answer within a named session. Requires `X-Session-Id` header."""
 
 
198
  if x_session_id not in _sessions:
199
  raise HTTPException(status_code=404,
200
  detail=f"Session {x_session_id} not found. Call /session/reset first.")
 
206
  return result
207
 
208
 
209
+ @app.delete("/session", summary="Close a session", tags=["Sessions"])
210
+ async def close_session(x_session_id: str = Header(...)):
211
  """Close and clean up a session."""
212
  if x_session_id in _sessions:
213
  _sessions[x_session_id].close()
 
215
  return {"status": "closed", "session_id": x_session_id}
216
 
217
 
218
+ @app.get("/session/list", summary="List active sessions", tags=["Sessions"])
219
+ async def list_sessions():
220
  return {"active_sessions": len(_sessions), "session_ids": list(_sessions.keys())}
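The `X-Session-Id` contract above (create via `/session/reset`, reuse the returned id on every `/session/step`, clean up with `DELETE /session`) can be wrapped in a thin client. A sketch assuming only the endpoints defined in this file; `SessionClient` is illustrative, not a shipped class:

```python
import uuid
import requests

class SessionClient:
    """Minimal client for the stateful /session endpoints."""

    def __init__(self, base_url: str):
        self.base_url = base_url.rstrip("/")
        self.session_id = str(uuid.uuid4())

    def _headers(self) -> dict:
        # Every session call must carry the X-Session-Id header.
        return {"X-Session-Id": self.session_id}

    def reset(self, **params) -> dict:
        r = requests.post(f"{self.base_url}/session/reset",
                          json=params, headers=self._headers())
        r.raise_for_status()
        return r.json()

    def step(self, answer: str) -> dict:
        r = requests.post(f"{self.base_url}/session/step",
                          json={"answer": answer}, headers=self._headers())
        r.raise_for_status()
        return r.json()

    def close(self) -> dict:
        r = requests.delete(f"{self.base_url}/session", headers=self._headers())
        r.raise_for_status()
        return r.json()
```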
221
 
222
 
223
+ # ── Leaderboard ────────────────────────────────────────────────────────────────
224
+
225
+ @app.get("/leaderboard", summary="Model leaderboard", tags=["Leaderboard"])
226
+ async def get_leaderboard():
227
+ """
228
+ Returns ranked leaderboard of all submitted model evaluations.
229
+ Ranked by avg_reward descending.
230
+ """
231
+ if not _leaderboard:
232
+ return {"leaderboard": [], "total_models": 0,
233
+ "message": "No models submitted yet. Use POST /leaderboard/submit"}
234
+ ranked = sorted((dict(e) for e in _leaderboard.values()), key=lambda x: x.get("avg_reward", 0), reverse=True)  # copy so "rank" isn't written back into _leaderboard
235
+ for i, entry in enumerate(ranked):
236
+ entry["rank"] = i + 1
237
+ return {
238
+ "leaderboard": ranked,
239
+ "total_models": len(ranked),
240
+ "last_updated": max(e.get("submitted_at", 0) for e in ranked),
241
+ }
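The ranking step is small enough to factor into a pure helper; copying each entry before annotating means `rank` is never written back into the callers' dicts. A sketch of that variant:

```python
from typing import Any, Dict, List

def rank_entries(entries: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Sort leaderboard entries by avg_reward (desc) and add a 1-based rank.

    Operates on copies, so the input dicts are left untouched.
    """
    ranked = sorted((dict(e) for e in entries),
                    key=lambda e: e.get("avg_reward", 0), reverse=True)
    for i, e in enumerate(ranked, start=1):
        e["rank"] = i
    return ranked
```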
242
+
243
+
244
+ @app.post("/leaderboard/submit", summary="Submit model evaluation results", tags=["Leaderboard"])
245
+ async def submit_to_leaderboard(data: Dict[str, Any]):
246
+ """
247
+ Submit your model's evaluation results to the leaderboard.
248
+
249
+ **Required fields:**
250
+ ```json
251
+ {
252
+ "model_name": "gpt-4o",
253
+ "avg_reward": 0.72,
254
+ "avg_accuracy": 0.81,
255
+ "hallucination_rate": 0.19,
256
+ "total_episodes": 10,
257
+ "total_steps": 100
258
+ }
259
+ ```
260
+ **Optional:** `organization`, `model_version`, `notes`
261
+ """
262
+ required = ["model_name", "avg_reward", "avg_accuracy",
263
+ "hallucination_rate", "total_episodes", "total_steps"]
264
+ missing = [f for f in required if f not in data]
265
+ if missing:
266
+ raise HTTPException(status_code=422,
267
+ detail=f"Missing required fields: {missing}")
268
+ model_name = data["model_name"]
269
+ _leaderboard[model_name] = {
270
+ "model_name": model_name,
271
+ "organization": data.get("organization", ""),
272
+ "model_version": data.get("model_version", ""),
273
+ "avg_reward": round(float(data["avg_reward"]), 4),
274
+ "avg_accuracy": round(float(data["avg_accuracy"]), 4),
275
+ "hallucination_rate": round(float(data["hallucination_rate"]), 4),
276
+ "total_episodes": int(data["total_episodes"]),
277
+ "total_steps": int(data["total_steps"]),
278
+ "notes": data.get("notes", ""),
279
+ "submitted_at": time.time(),
280
+ }
281
+ logger.info(f"Leaderboard submission: {model_name} reward={data['avg_reward']:.3f}")
282
+ return {"status": "submitted", "model_name": model_name,
283
+ "message": f"'{model_name}' added to leaderboard. View at /leaderboard"}
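Clients can mirror the server's required-field check before POSTing, turning a 422 round-trip into a local error. `build_submission` is a hypothetical client-side helper:

```python
REQUIRED_FIELDS = ["model_name", "avg_reward", "avg_accuracy",
                   "hallucination_rate", "total_episodes", "total_steps"]

def build_submission(results: dict, organization: str = "", notes: str = "") -> dict:
    """Validate and assemble a /leaderboard/submit payload locally."""
    missing = [f for f in REQUIRED_FIELDS if f not in results]
    if missing:
        raise ValueError(f"Missing required fields: {missing}")
    payload = {f: results[f] for f in REQUIRED_FIELDS}
    payload["organization"] = organization
    payload["notes"] = notes
    return payload
```

The result is what `requests.post(f"{BASE}/leaderboard/submit", json=build_submission(results))` would send.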
284
+
285
+
286
+ @app.delete("/leaderboard/{model_name}", summary="Remove from leaderboard", tags=["Leaderboard"])
287
+ async def remove_from_leaderboard(model_name: str):
288
+ """Remove a model entry from the leaderboard."""
289
+ if model_name not in _leaderboard:
290
+ raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
291
+ del _leaderboard[model_name]
292
+ return {"status": "removed", "model_name": model_name}
293
+
294
 
295
+ # ── Info & metrics ─────────────────────────────────────────────────────────────
296
+
297
+ @app.get("/health", summary="Health check", tags=["Info"])
298
  async def health():
299
+ return {"status": "healthy", "service": "HallucinationGuard-Env", "version": "3.0.0"}
300
+
301
+
302
+ @app.get("/environment/info", summary="Full environment spec", tags=["Info"])
303
+ async def env_info():
304
+ return {
305
+ "name": "HallucinationGuard-Env",
306
+ "version": "3.0.0",
307
+ "description": "Production RL environment for hallucination detection & prevention",
308
+ "datasets": {
309
+ "count": 13,
310
+ "total_examples": "50,000+",
311
+ "sources": [
312
+ "squad", "trivia_qa", "halueval", "truthful_qa",
313
+ "natural_questions", "hotpotqa", "boolq", "faithdial",
314
+ "fever", "arc", "openbookqa", "ms_marco", "coqa",
315
+ ],
316
+ },
317
+ "endpoints": {
318
+ "environment": ["/reset", "/step", "/state"],
319
+ "sessions": ["/session/reset", "/session/step", "/session/list", "/session"],
320
+ "leaderboard": ["/leaderboard", "/leaderboard/submit"],
321
+ "info": ["/health", "/environment/info", "/datasets", "/metrics"],
322
+ },
323
+ "difficulty_levels": ["beginner", "intermediate", "advanced", "expert"],
324
+ "hallucination_types": [
325
+ "fabricated_fact", "false_citation", "overconfident_wrong",
326
+ "context_drift", "numerical_fabrication", "entity_confusion",
327
+ ],
328
+ "reward_range": [-1.0, 1.0],
329
+ "supported_frameworks": ["OpenAI Gym", "OpenEnv", "custom Python", "REST API"],
330
+ }
331
+
332
+
333
+ @app.get("/datasets", summary="Dataset statistics", tags=["Info"])
334
+ async def dataset_info():
335
+ """Returns breakdown of loaded datasets by source, difficulty, and category."""
336
+ try:
337
+ env = _get_default_env()
338
+ stats = env.dataset_loader.get_statistics()
339
+ return {
340
+ "total_examples": stats.total_examples,
341
+ "by_source": stats.examples_by_source,
342
+ "by_difficulty": stats.examples_by_difficulty,
343
+ "by_category": stats.examples_by_category,
344
+ "avg_context_length": round(stats.average_context_length, 1),
345
+ "avg_question_length": round(stats.average_question_length, 1),
346
+ }
347
+ except Exception as e:
348
+ raise HTTPException(status_code=500, detail=str(e))
349
 
350
 
351
+ @app.get("/metrics", summary="Real-time metrics", tags=["Metrics"])
352
  async def get_metrics():
353
  try:
354
  return get_tracker().get_real_time_metrics()
 
356
  raise HTTPException(status_code=500, detail=str(e))
357
 
358
 
359
+ @app.get("/metrics/summary", summary="Metrics summary report", tags=["Metrics"])
360
  async def metrics_summary():
361
  try:
362
  return {"summary": get_tracker().generate_summary_report()}
 
364
  raise HTTPException(status_code=500, detail=str(e))
365
 
366
 
367
+ # ── Middleware ─────────────────────────────────────────────────────────────────

368
 
369
  @app.middleware("http")
370
  async def log_requests(request, call_next):
 
373
  return response
374
 
375

376
  if __name__ == "__main__":
377
  import uvicorn
378
+ uvicorn.run(app, host="0.0.0.0", port=7860)
server/dataset_loader.py CHANGED
The diff for this file is too large to render. See raw diff