msmaje committed on
Commit
676c241
·
verified ·
1 Parent(s): adf71b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -95
app.py CHANGED
@@ -1,98 +1,146 @@
1
- """
2
- Gradio Space for Human-AI Text Attribution (HATA) Model
3
- Detects whether text is human-written or AI-generated
4
- Supports multiple African languages
5
- """
6
-
7
- # --- Deterministic suppression of Gradio audio stack under Python 3.13 ---
8
  import os
9
- import sys
10
- import types
11
-
12
- os.environ["GRADIO_DISABLE_PYDUB"] = "1"
13
-
14
- # Provide stubs so that pydub cannot fail on audioop / pyaudioop
15
- if "audioop" not in sys.modules:
16
- sys.modules["audioop"] = types.ModuleType("audioop")
17
- if "pyaudioop" not in sys.modules:
18
- sys.modules["pyaudioop"] = types.ModuleType("pyaudioop")
19
-
20
- # Now it is safe to import Gradio and the rest of the stack
21
- import gradio as gr
22
- import torch
23
- import numpy as np
24
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
25
-
26
- # ----------------------------------------------------------------------
27
- # Model configuration
28
- # ----------------------------------------------------------------------
29
- MODEL_NAME = "distilbert-base-multilingual-cased" # replace with your fine-tuned HATA checkpoint if available
30
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
31
-
32
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
33
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
34
- model.to(DEVICE)
35
- model.eval()
36
-
37
- LABELS = ["Human-written", "AI-generated"]
38
-
39
- # ----------------------------------------------------------------------
40
- # Inference routine
41
- # ----------------------------------------------------------------------
42
@torch.no_grad()
def hata_predict(text: str):
    """Classify ``text`` and return a probability for each attribution label.

    Args:
        text: Passage to classify. ``None``, empty, and whitespace-only
            input short-circuits without touching the model.

    Returns:
        A dict mapping each entry of ``LABELS`` to its softmax probability.
        For blank input both labels map to ``0.0``.
    """
    if not text or not text.strip():
        return {"Human-written": 0.0, "AI-generated": 0.0}

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    ).to(DEVICE)

    # A single passage is encoded, so drop the leading batch dimension
    # before normalising the logits into probabilities.
    logits = model(**inputs).logits.squeeze(0)
    probs = torch.softmax(logits, dim=-1).tolist()

    # Pair labels with probabilities positionally instead of indexing.
    return dict(zip(LABELS, probs))
60
-
61
- # ----------------------------------------------------------------------
62
- # Gradio interface
63
- # ----------------------------------------------------------------------
64
- with gr.Blocks(title="Multilingual HATA System") as demo:
65
- gr.Markdown(
66
- """
67
- # Multilingual Human–AI Text Attribution (HATA)
68
-
69
- This system estimates whether an input passage is **human-written** or
70
- **AI-generated**, with a focus on multilingual and African-language use
71
- cases (e.g., Hausa, Yoruba, Igbo, Pidgin).
72
-
73
- The backend is a Transformer-based classifier fine-tuned for attribution.
74
- """
75
- )
76
-
77
- with gr.Row():
78
- with gr.Column(scale=3):
79
- text_input = gr.Textbox(
80
- label="Input Text",
81
- placeholder="Paste a paragraph in Hausa, Yoruba, Igbo, Pidgin, or English...",
82
- lines=8,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  )
84
- submit_btn = gr.Button("Analyze")
85
- with gr.Column(scale=2):
86
- output = gr.Label(label="Attribution Probabilities")
87
-
88
- submit_btn.click(
89
- fn=hata_predict,
90
- inputs=text_input,
91
- outputs=output,
92
- )
93
-
94
- # ----------------------------------------------------------------------
95
- # Entry point
96
- # ----------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  if __name__ == "__main__":
98
- demo.launch()
 
1
+ # app.py
 
 
 
 
 
 
2
  import os
3
+ import math
4
+ import requests
5
+ from flask import Flask, request, jsonify
6
+ from flask_cors import CORS
7
+ from langdetect import detect
8
+
9
+ # -----------------------------------------------------------------------------
10
+ # Configuration
11
+ # -----------------------------------------------------------------------------
12
# The model URL can be overridden through the environment; the default is
# the placeholder shipped with this file and must be replaced before use.
HF_API_URL = os.getenv(
    "HF_API_URL",
    "https://api-inference.huggingface.co/models/YOUR_USERNAME/YOUR_MODEL",
)
HF_TOKEN = os.getenv("HF_TOKEN")

# Only attach an Authorization header when a token is actually configured;
# otherwise the literal string "Bearer None" would be sent upstream.
HEADERS = {"Content-Type": "application/json"}
if HF_TOKEN:
    HEADERS["Authorization"] = f"Bearer {HF_TOKEN}"

app = Flask(__name__)
CORS(app)
22
+
23
+ # -----------------------------------------------------------------------------
24
+ # Utility Functions
25
+ # -----------------------------------------------------------------------------
26
def entropy(probs):
    """Return the Shannon entropy, in bits, of a discrete distribution.

    Args:
        probs: Iterable of probabilities. Entries that are not strictly
            positive are skipped, so zeros contribute nothing.

    Returns:
        The entropy as a non-negative-zero float: ``0.0`` for an empty or
        one-hot distribution, ``1.0`` for a uniform two-class distribution.
    """
    weighted_logs = sum(p * math.log2(p) for p in probs if p > 0)
    # Subtract from 0.0 instead of negating: negation turns a zero sum into
    # -0.0, which would be serialized as "-0.0" in the JSON response.
    return 0.0 - weighted_logs
29
+
30
def normalize_labels(hf_output):
    """Normalize Hugging Face classifier output into ``(human_p, ai_p)``.

    Accepts either the flat format::

        [
            {"label": "HUMAN", "score": 0.73},
            {"label": "AI", "score": 0.27}
        ]

    or the same list wrapped in one extra list (one inner list per input),
    in which case the first inner list is used.

    Args:
        hf_output: List of ``{"label", "score"}`` dicts, optionally nested
            one level deep.

    Returns:
        A ``(human_p, ai_p)`` tuple of floats. Labels are matched
        case-insensitively against ``"human"`` and ``"ai"``; a label that
        is absent yields ``0.0``.
    """
    items = hf_output
    # Unwrap the per-input nesting so both response shapes are handled.
    if items and isinstance(items[0], (list, tuple)):
        items = items[0]

    scores = {item["label"].lower(): float(item["score"]) for item in items}
    return scores.get("human", 0.0), scores.get("ai", 0.0)
44
+
45
def hf_inference(text):
    """Send ``text`` to the configured model endpoint and return its JSON reply.

    The request is a POST to ``HF_API_URL`` with ``HEADERS`` and a 30-second
    timeout. A non-success HTTP status is raised by ``raise_for_status()``
    rather than returned.
    """
    response = requests.post(
        HF_API_URL,
        json={"inputs": text},
        headers=HEADERS,
        timeout=30,
    )
    response.raise_for_status()
    return response.json()
50
+
51
+ # -----------------------------------------------------------------------------
52
+ # Core Endpoint
53
+ # -----------------------------------------------------------------------------
54
@app.route("/analyze", methods=["POST"])
def analyze():
    """Classify the posted text as human- or machine-written.

    Expects a JSON object with a string ``"text"`` field.

    Returns:
        200 with prediction, probabilities, entropy, detected language and
        the explainability/fairness placeholder fields on success.
        400 when the body is not a JSON object, ``"text"`` is not a string,
        or the text is empty after stripping.
        502 when the call to the model endpoint fails.
        500 when the model endpoint answers with an unexpected payload.
    """
    # silent=True turns unparsable or non-JSON bodies into None so they are
    # reported as a client error below instead of raising inside Flask.
    data = request.get_json(silent=True)
    raw_text = data.get("text", "") if isinstance(data, dict) else None
    if not isinstance(raw_text, str):
        return jsonify({
            "error": "Request body must be a JSON object with a string 'text' field"
        }), 400

    text = raw_text.strip()
    if not text:
        return jsonify({"error": "Empty input"}), 400

    # 1. Language detection (supports linguistic auditing).
    # Deliberately best-effort: any detector failure degrades to "unknown"
    # rather than failing the whole request.
    try:
        language = detect(text)
    except Exception:
        language = "unknown"

    # 2. Hugging Face inference
    try:
        hf_raw = hf_inference(text)
    except (requests.RequestException, ValueError) as exc:
        # Network errors, HTTP error statuses and undecodable bodies are
        # upstream failures, not bugs in this service.
        return jsonify({
            "error": "Model inference request failed",
            "detail": str(exc)
        }), 502

    if not isinstance(hf_raw, list):
        return jsonify({"error": "Unexpected model response", "raw": hf_raw}), 500

    try:
        human_p, ai_p = normalize_labels(hf_raw)
    except (AttributeError, KeyError, TypeError, ValueError):
        return jsonify({"error": "Unexpected model response", "raw": hf_raw}), 500

    # Both scores being zero means neither expected label was present;
    # reporting "Human" with confidence 0.0 would be misleading.
    if human_p == 0.0 and ai_p == 0.0:
        return jsonify({"error": "Unexpected model response", "raw": hf_raw}), 500

    # 3. Decision (ties resolve to "Human")
    label = "Human" if human_p >= ai_p else "Machine"
    confidence = max(human_p, ai_p)

    # 4. Entropy of the two-class distribution as the uncertainty signal
    H = entropy([human_p, ai_p])

    # 5. Explainability placeholder (XAI-ready schema)
    explainability_stub = {
        "method": "pending",
        "note": (
            "This model endpoint does not natively expose SHAP/LIME. "
            "Post-hoc explainability must be computed locally using a "
            "replicated model or proxy explainer."
        ),
        "token_attributions": []
    }

    # 6. Fairness metadata (for downstream auditing); values are unrounded.
    fairness_context = {
        "language": language,
        "human_probability": human_p,
        "ai_probability": ai_p,
        "entropy": H
    }

    response = {
        "prediction": {
            "label": label,
            "confidence": round(confidence, 4)
        },
        "probabilities": {
            "human": round(human_p, 4),
            "machine": round(ai_p, 4)
        },
        "uncertainty": {
            "entropy": round(H, 4),
            "interpretation": (
                "High entropy indicates epistemic ambiguity; "
                "classification should be treated cautiously."
            )
        },
        "linguistic_context": {
            "detected_language": language
        },
        "explainability": explainability_stub,
        "fairness_audit_fields": fairness_context
    }

    return jsonify(response)
126
+
127
+ # -----------------------------------------------------------------------------
128
+ # Health Check
129
+ # -----------------------------------------------------------------------------
130
@app.route("/", methods=["GET"])
def index():
    """Health check: report the service name and its advertised capabilities."""
    capabilities = [
        "Human vs AI classification",
        "Probability calibration",
        "Uncertainty estimation",
        "Language-aware auditing",
        "Explainability-ready schema",
        "Fairness instrumentation",
    ]
    return jsonify({"system": "HATA API", "capabilities": capabilities})
143
+
144
+ # -----------------------------------------------------------------------------
145
  if __name__ == "__main__":
146
+ app.run(host="0.0.0.0", port=5000, debug=True)