petter2025 committed on
Commit
11c85bb
·
verified ·
1 Parent(s): d646702

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +234 -106
app.py CHANGED
@@ -4,86 +4,132 @@ import json
4
  import logging
5
  import traceback
6
  import os
 
7
  import numpy as np
8
  from datetime import datetime
9
- from transformers import pipeline, set_seed
10
- import torch
 
 
 
 
11
 
12
- # Import our components
13
  from agentic_reliability_framework.runtime.engine import EnhancedReliabilityEngine
14
  from hallucination_detective import HallucinationDetectiveAgent
15
  from memory_drift_diagnostician import MemoryDriftDiagnosticianAgent
 
 
16
  from ai_event import AIEvent
17
  from ai_risk_engine import AIRiskEngine
18
  from nli_detector import NLIDetector
 
19
 
20
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
21
  logger = logging.getLogger(__name__)
22
 
23
- # Initialize infrastructure engine (optional)
 
 
24
  try:
25
  logger.info("Initializing EnhancedReliabilityEngine...")
26
  engine = EnhancedReliabilityEngine()
27
- logger.info("Engine initialized successfully.")
28
  except Exception as e:
29
- logger.error(f"Failed to initialize engine: {e}\n{traceback.format_exc()}")
30
  engine = None
31
 
32
- # Load generative model (small autoregressive)
 
 
33
  gen_model_name = "microsoft/DialoGPT-small"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  try:
35
- generator = pipeline('text-generation', model=gen_model_name, device=0 if torch.cuda.is_available() else -1)
36
- logger.info(f"Generator {gen_model_name} loaded.")
 
 
 
 
37
  except Exception as e:
38
- logger.error(f"Failed to load generator: {e}")
39
- generator = None
40
 
41
- # Load NLI detector
42
- nli_detector = NLIDetector()
 
 
 
 
 
 
 
 
 
 
 
43
 
 
44
  # AI agents
 
45
  hallucination_detective = HallucinationDetectiveAgent(nli_detector=nli_detector)
46
  memory_drift_diagnostician = MemoryDriftDiagnosticianAgent()
 
 
47
 
48
- # AI risk engine
 
 
49
  ai_risk_engine = AIRiskEngine()
50
 
51
- # In‑memory storage for last event to attach feedback
52
- last_ai_event = None
53
- last_ai_category = None
54
-
55
- async def generate_response(prompt: str, max_length: int = 100) -> tuple:
56
- """Generate response using the small autoregressive model."""
57
- if generator is None:
58
- return "[Model not loaded]", 0.0, "Model loading failed"
59
- try:
60
- loop = asyncio.get_event_loop()
61
- # We need to compute confidence; text-generation pipeline returns text but not logits.
62
- # For simplicity, we'll set confidence based on a heuristic (e.g., generation length?).
63
- # Alternatively, use a model that returns probabilities.
64
- # Let's use a simple placeholder: confidence = 0.8 if generation succeeds.
65
- # In practice, we'd need to access logits.
66
- result = await loop.run_in_executor(
67
- None,
68
- lambda: generator(prompt, max_new_tokens=max_length, return_full_text=False)
69
  )
70
- response = result[0]['generated_text']
71
- # Placeholder confidence
72
- confidence = 0.8
73
- return response, confidence, ""
74
- except Exception as e:
75
- logger.error(f"Generation error: {e}")
76
- return "", 0.0, str(e)
 
 
 
77
 
78
- async def analyze_ai(task_type, prompt):
79
- global last_ai_event, last_ai_category
 
 
80
  try:
81
- # Generate response
82
- response, confidence, error = await generate_response(prompt)
83
- if error:
84
- return json.dumps({"error": error}, indent=2)
85
-
86
- # Create AIEvent
87
  event = AIEvent(
88
  timestamp=datetime.utcnow(),
89
  component="ai",
@@ -99,91 +145,173 @@ async def analyze_ai(task_type, prompt):
99
  prompt=prompt,
100
  response=response,
101
  response_length=len(response),
102
- confidence=confidence,
103
  perplexity=None,
104
- retrieval_scores=None,
105
  user_feedback=None,
106
  latency_ms=0
107
  )
108
- last_ai_event = event
109
- last_ai_category = task_type
110
-
111
- # Run agents
112
  hallu_result = await hallucination_detective.analyze(event)
113
  drift_result = await memory_drift_diagnostician.analyze(event)
114
-
115
- # Get current risk metrics
116
  risk_metrics = ai_risk_engine.risk_score(task_type)
117
-
118
- result = {
119
  "response": response,
120
- "confidence": confidence,
 
 
121
  "hallucination_detection": hallu_result,
122
  "memory_drift_detection": drift_result,
123
  "risk_metrics": risk_metrics
124
  }
125
- return json.dumps(result, indent=2)
126
  except Exception as e:
127
- logger.error(f"AI analysis error: {e}\n{traceback.format_exc()}")
128
- return json.dumps({"error": str(e), "traceback": traceback.format_exc()}, indent=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  def feedback(thumbs_up: bool):
131
- """Handle user feedback to update Beta priors."""
132
- global last_ai_category, last_ai_event
133
- if last_ai_category is None:
134
  return "No previous analysis to rate."
135
- ai_risk_engine.update_outcome(last_ai_category, success=thumbs_up)
136
- # Optionally, also update the event with feedback
137
- if last_ai_event:
138
- last_ai_event.user_feedback = thumbs_up
139
- return f"Feedback recorded: {'👍' if thumbs_up else '👎'} for {last_ai_category}."
140
 
141
- # Build the Gradio interface
 
 
142
  with gr.Blocks(title="ARF v4 – AI Reliability Lab", theme="soft") as demo:
143
- gr.Markdown("# 🧠 ARF v4 – AI Reliability Lab\n**Detect hallucinations and drift in generative AI**")
144
 
145
- with gr.Row():
146
- with gr.Column():
147
- task_type = gr.Dropdown(
148
- choices=["chat", "code", "summary"],
149
- value="chat",
150
- label="Task Type"
151
- )
152
- prompt = gr.Textbox(
153
- label="Prompt",
154
- value="What is the capital of France?",
155
- lines=3
156
- )
157
- analyze_btn = gr.Button("Analyze", variant="primary")
158
- with gr.Column():
159
- output = gr.JSON(label="Analysis Result")
 
 
160
 
161
  with gr.Row():
162
- feedback_btn_up = gr.Button("👍 Correct")
163
- feedback_btn_down = gr.Button("👎 Incorrect")
164
  feedback_msg = gr.Textbox(label="Feedback", interactive=False)
165
 
166
- analyze_btn.click(
167
- fn=analyze_ai,
168
- inputs=[task_type, prompt],
169
- outputs=output
 
170
  )
171
- feedback_btn_up.click(
172
- fn=lambda: feedback(True),
173
- outputs=feedback_msg
 
174
  )
175
- feedback_btn_down.click(
176
- fn=lambda: feedback(False),
177
- outputs=feedback_msg
 
178
  )
179
-
180
- gr.Markdown("""
181
- ---
182
- - **Model**: `microsoft/DialoGPT-small` (autoregressive, 117M params)
183
- - **NLI Detector**: `typeform/distilroberta-base-mnli` (82M params)
184
- - **Risk engine**: Beta conjugate priors per task category
185
- - **Feedback** updates the posterior distribution
186
- """)
187
 
188
  if __name__ == "__main__":
189
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
4
  import logging
5
  import traceback
6
  import os
7
+ import torch
8
  import numpy as np
9
  from datetime import datetime
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
11
+ from sentence_transformers import SentenceTransformer, util
12
+ from diffusers import StableDiffusionPipeline
13
+ import librosa
14
+ import soundfile as sf
15
+ import tempfile
16
 
17
+ # ARF components
18
  from agentic_reliability_framework.runtime.engine import EnhancedReliabilityEngine
19
  from hallucination_detective import HallucinationDetectiveAgent
20
  from memory_drift_diagnostician import MemoryDriftDiagnosticianAgent
21
+ from image_detector import ImageQualityDetector
22
+ from audio_detector import AudioQualityDetector
23
  from ai_event import AIEvent
24
  from ai_risk_engine import AIRiskEngine
25
  from nli_detector import NLIDetector
26
+ from retrieval import SimpleRetriever
27
 
28
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
29
  logger = logging.getLogger(__name__)
30
 
31
+ # ----------------------------------------------------------------------
32
+ # Infrastructure engine (optional)
33
+ # ----------------------------------------------------------------------
34
  try:
35
  logger.info("Initializing EnhancedReliabilityEngine...")
36
  engine = EnhancedReliabilityEngine()
 
37
  except Exception as e:
38
+ logger.error(f"Engine init failed: {e}")
39
  engine = None
40
 
41
+ # ----------------------------------------------------------------------
42
+ # Generative model for text (DialoGPT-small)
43
+ # ----------------------------------------------------------------------
44
  gen_model_name = "microsoft/DialoGPT-small"
45
+ tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
46
+ model = AutoModelForCausalLM.from_pretrained(gen_model_name)
47
+ logger.info(f"Generator {gen_model_name} loaded.")
48
+
49
+ # ----------------------------------------------------------------------
50
+ # NLI detector
51
+ # ----------------------------------------------------------------------
52
+ nli_detector = NLIDetector()
53
+
54
+ # ----------------------------------------------------------------------
55
+ # Sentence‑Transformer retriever
56
+ # ----------------------------------------------------------------------
57
+ retriever = SimpleRetriever()
58
+ logger.info("Retriever loaded.")
59
+
60
+ # ----------------------------------------------------------------------
61
+ # Image generation (tiny model for demo)
62
+ # ----------------------------------------------------------------------
63
  try:
64
+ image_pipe = StableDiffusionPipeline.from_pretrained(
65
+ "hf-internal-testing/tiny-stable-diffusion-torch"
66
+ )
67
+ if not torch.cuda.is_available():
68
+ image_pipe.to("cpu")
69
+ logger.info("Image pipeline loaded.")
70
  except Exception as e:
71
+ logger.error(f"Image pipeline failed: {e}")
72
+ image_pipe = None
73
 
74
+ # ----------------------------------------------------------------------
75
+ # Audio transcription (Whisper tiny)
76
+ # ----------------------------------------------------------------------
77
+ try:
78
+ audio_pipe = pipeline(
79
+ "automatic-speech-recognition",
80
+ model="openai/whisper-tiny.en",
81
+ device=0 if torch.cuda.is_available() else -1
82
+ )
83
+ logger.info("Audio pipeline loaded.")
84
+ except Exception as e:
85
+ logger.error(f"Audio pipeline failed: {e}")
86
+ audio_pipe = None
87
 
88
+ # ----------------------------------------------------------------------
89
  # AI agents
90
+ # ----------------------------------------------------------------------
91
  hallucination_detective = HallucinationDetectiveAgent(nli_detector=nli_detector)
92
  memory_drift_diagnostician = MemoryDriftDiagnosticianAgent()
93
+ image_quality_detector = ImageQualityDetector()
94
+ audio_quality_detector = AudioQualityDetector()
95
 
96
+ # ----------------------------------------------------------------------
97
+ # Bayesian risk engine
98
+ # ----------------------------------------------------------------------
99
  ai_risk_engine = AIRiskEngine()
100
 
101
+ # ----------------------------------------------------------------------
102
+ # Generation helper with log probabilities
103
+ # ----------------------------------------------------------------------
104
+ def generate_with_logprobs(prompt, max_new_tokens=100):
105
+ inputs = tokenizer(prompt, return_tensors="pt")
106
+ with torch.no_grad():
107
+ outputs = model.generate(
108
+ **inputs,
109
+ max_new_tokens=max_new_tokens,
110
+ return_dict_in_generate=True,
111
+ output_scores=True
 
 
 
 
 
 
 
112
  )
113
+ scores = outputs.scores
114
+ log_probs = [torch.log_softmax(score, dim=-1) for score in scores]
115
+ generated_ids = outputs.sequences[0][inputs['input_ids'].shape[1]:]
116
+ token_log_probs = []
117
+ for i, lp in enumerate(log_probs):
118
+ token_id = generated_ids[i]
119
+ token_log_probs.append(lp[0, token_id].item())
120
+ avg_log_prob = sum(token_log_probs) / len(token_log_probs) if token_log_probs else 0.0
121
+ generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
122
+ return generated_text, avg_log_prob
123
 
124
+ # ----------------------------------------------------------------------
125
+ # Task handlers
126
+ # ----------------------------------------------------------------------
127
+ async def handle_text(task_type, prompt):
128
  try:
129
+ response, avg_log_prob = generate_with_logprobs(prompt)
130
+ # Get retrieval score
131
+ retrieval_score = retriever.get_similarity(prompt)
132
+ # Create event
 
 
133
  event = AIEvent(
134
  timestamp=datetime.utcnow(),
135
  component="ai",
 
145
  prompt=prompt,
146
  response=response,
147
  response_length=len(response),
148
+ confidence=float(np.exp(avg_log_prob)), # convert to probability scale
149
  perplexity=None,
150
+ retrieval_scores=[retrieval_score],
151
  user_feedback=None,
152
  latency_ms=0
153
  )
154
+ # Analyze
 
 
 
155
  hallu_result = await hallucination_detective.analyze(event)
156
  drift_result = await memory_drift_diagnostician.analyze(event)
 
 
157
  risk_metrics = ai_risk_engine.risk_score(task_type)
158
+ return {
 
159
  "response": response,
160
+ "avg_log_prob": avg_log_prob,
161
+ "confidence": event.confidence,
162
+ "retrieval_score": retrieval_score,
163
  "hallucination_detection": hallu_result,
164
  "memory_drift_detection": drift_result,
165
  "risk_metrics": risk_metrics
166
  }
 
167
  except Exception as e:
168
+ logger.error(f"Text task error: {e}")
169
+ return {"error": str(e)}
170
+
171
+ async def handle_image(prompt):
172
+ if image_pipe is None:
173
+ return {"error": "Image model not loaded"}
174
+ try:
175
+ import time
176
+ start = time.time()
177
+ image = image_pipe(prompt, num_inference_steps=2).images[0] # tiny steps for speed
178
+ gen_time = time.time() - start
179
+ # Mock retrieval score (you could use CLIP similarity)
180
+ retrieval_score = retriever.get_similarity(prompt)
181
+ event = AIEvent(
182
+ timestamp=datetime.utcnow(),
183
+ component="image",
184
+ service_mesh="ai",
185
+ latency_p99=0,
186
+ error_rate=0.0,
187
+ throughput=1,
188
+ cpu_util=None,
189
+ memory_util=None,
190
+ action_category="image",
191
+ model_name="tiny-sd",
192
+ model_version="latest",
193
+ prompt=prompt,
194
+ response="", # image not text
195
+ response_length=0,
196
+ confidence=1.0 / (gen_time + 1), # heuristic
197
+ perplexity=None,
198
+ retrieval_scores=[retrieval_score, gen_time],
199
+ user_feedback=None,
200
+ latency_ms=gen_time * 1000
201
+ )
202
+ quality_result = await image_quality_detector.analyze(event)
203
+ return {
204
+ "image": image,
205
+ "generation_time": gen_time,
206
+ "retrieval_score": retrieval_score,
207
+ "quality_detection": quality_result
208
+ }
209
+ except Exception as e:
210
+ logger.error(f"Image task error: {e}")
211
+ return {"error": str(e)}
212
 
213
+ async def handle_audio(audio_file):
214
+ if audio_pipe is None:
215
+ return {"error": "Audio model not loaded"}
216
+ try:
217
+ # Load audio (Gradio provides file path)
218
+ audio, sr = librosa.load(audio_file, sr=16000)
219
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
220
+ sf.write(tmp.name, audio, sr)
221
+ result = audio_pipe(tmp.name, return_timestamps=False)
222
+ text = result["text"]
223
+ # Whisper does not output log probs easily; we'll use a placeholder
224
+ avg_log_prob = -2.0 # placeholder
225
+ event = AIEvent(
226
+ timestamp=datetime.utcnow(),
227
+ component="audio",
228
+ service_mesh="ai",
229
+ latency_p99=0,
230
+ error_rate=0.0,
231
+ throughput=1,
232
+ cpu_util=None,
233
+ memory_util=None,
234
+ action_category="audio",
235
+ model_name="whisper-tiny.en",
236
+ model_version="latest",
237
+ prompt="", # audio file path
238
+ response=text,
239
+ response_length=len(text),
240
+ confidence=float(np.exp(avg_log_prob)),
241
+ perplexity=None,
242
+ retrieval_scores=[avg_log_prob],
243
+ user_feedback=None,
244
+ latency_ms=0
245
+ )
246
+ quality_result = await audio_quality_detector.analyze(event)
247
+ return {
248
+ "transcription": text,
249
+ "avg_log_prob": avg_log_prob,
250
+ "confidence": event.confidence,
251
+ "quality_detection": quality_result
252
+ }
253
+ except Exception as e:
254
+ logger.error(f"Audio task error: {e}")
255
+ return {"error": str(e)}
256
+
257
+ # ----------------------------------------------------------------------
258
+ # Feedback handling
259
+ # ----------------------------------------------------------------------
260
+ last_event_category = None
261
  def feedback(thumbs_up: bool):
262
+ global last_event_category
263
+ if last_event_category is None:
 
264
  return "No previous analysis to rate."
265
+ ai_risk_engine.update_outcome(last_event_category, success=thumbs_up)
266
+ return f"Feedback recorded: {'👍' if thumbs_up else '👎'} for {last_event_category}."
 
 
 
267
 
268
+ # ----------------------------------------------------------------------
269
+ # Gradio UI
270
+ # ----------------------------------------------------------------------
271
  with gr.Blocks(title="ARF v4 – AI Reliability Lab", theme="soft") as demo:
272
+ gr.Markdown("# 🧠 ARF v4 – AI Reliability Lab\n**Detect hallucinations, drift, and failures across text, image, and audio**")
273
 
274
+ with gr.Tabs():
275
+ with gr.TabItem("Text Generation"):
276
+ text_task = gr.Dropdown(["chat", "code", "summary"], value="chat", label="Task")
277
+ text_prompt = gr.Textbox(label="Prompt", value="What is the capital of France?")
278
+ text_btn = gr.Button("Generate")
279
+ text_output = gr.JSON(label="Analysis")
280
+
281
+ with gr.TabItem("Image Generation"):
282
+ img_prompt = gr.Textbox(label="Prompt", value="A cat wearing a hat")
283
+ img_btn = gr.Button("Generate")
284
+ img_output = gr.Image(label="Generated Image")
285
+ img_json = gr.JSON(label="Analysis")
286
+
287
+ with gr.TabItem("Audio Transcription"):
288
+ audio_input = gr.Audio(type="filepath", label="Upload audio file")
289
+ audio_btn = gr.Button("Transcribe")
290
+ audio_output = gr.JSON(label="Analysis")
291
 
292
  with gr.Row():
293
+ feedback_up = gr.Button("👍 Correct")
294
+ feedback_down = gr.Button("👎 Incorrect")
295
  feedback_msg = gr.Textbox(label="Feedback", interactive=False)
296
 
297
+ # Wire up events
298
+ text_btn.click(
299
+ fn=lambda task, p: asyncio.run(handle_text(task, p)),
300
+ inputs=[text_task, text_prompt],
301
+ outputs=text_output
302
  )
303
+ img_btn.click(
304
+ fn=lambda p: asyncio.run(handle_image(p)),
305
+ inputs=img_prompt,
306
+ outputs=[img_output, img_json]
307
  )
308
+ audio_btn.click(
309
+ fn=lambda f: asyncio.run(handle_audio(f)),
310
+ inputs=audio_input,
311
+ outputs=audio_output
312
  )
313
+ feedback_up.click(fn=lambda: feedback(True), outputs=feedback_msg)
314
+ feedback_down.click(fn=lambda: feedback(False), outputs=feedback_msg)
 
 
 
 
 
 
315
 
316
  if __name__ == "__main__":
317
  demo.launch(server_name="0.0.0.0", server_port=7860)