Pushkar27 commited on
Commit
6824a07
Β·
1 Parent(s): d923478

CRITICAL: Remove all escaped underscores from YAML metadata

Browse files
Files changed (1) hide show
  1. README.md +9 -17
README.md CHANGED
@@ -14,7 +14,7 @@ tags:
14
  - nlp
15
  - pragmatics
16
  datasets:
17
- - topical_chat
18
  metrics:
19
  - f1
20
  - precision
@@ -30,7 +30,7 @@ model-index:
30
  name: Multi-Label Gricean Maxim Violation Detection
31
  dataset:
32
  name: Topical-Chat (GriceBench held-out split, N=1000)
33
- type: topical_chat
34
  split: test
35
  metrics:
36
  - type: f1
@@ -60,9 +60,9 @@ model-index:
60
  [![HuggingFace](https://img.shields.io/badge/πŸ€—-GriceBench-yellow)](https://huggingface.co/Pushkar27)
61
  [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
62
 
63
- **Part of the GriceBench system** β€”
64
- [GitHub](https://github.com/PushkarPrabhath27/Research-Model) |
65
- [πŸ”§ Repair Model](https://huggingface.co/Pushkar27/GriceBench-Repair) |
66
  [⚑ DPO Generator](https://huggingface.co/Pushkar27/GriceBench-DPO)
67
 
68
  </div>
@@ -92,12 +92,11 @@ import torch.nn as nn
92
  import json
93
  from transformers import AutoTokenizer, AutoModel
94
 
95
- # ── Define model architecture (must match training) ─────────────────────────
96
  class MaximDetector(nn.Module):
97
  def __init__(self, model_name="microsoft/deberta-v3-base", num_maxims=4):
98
  super().__init__()
99
  self.encoder = AutoModel.from_pretrained(model_name)
100
- hidden = self.encoder.config.hidden_size # 768
101
  self.classifiers = nn.ModuleList([
102
  nn.Sequential(
103
  nn.Dropout(0.15),
@@ -114,8 +113,6 @@ class MaximDetector(nn.Module):
114
  cls = outputs.last_hidden_state[:, 0, :]
115
  return torch.cat([head(cls) for head in self.classifiers], dim=1)
116
 
117
- # ── Load model and calibration ──────────────────────────────────────────────
118
- # Download pytorch_model.pt and temperatures.json from this repo first
119
  tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
120
  model = MaximDetector()
121
  state_dict = torch.load("pytorch_model.pt", map_location="cpu")
@@ -125,9 +122,8 @@ model.eval()
125
  with open("temperatures.json") as f:
126
  temperatures = json.load(f)
127
 
128
- # ── Detect violations ───────────────────────────────────────────────────────
129
  def detect_violations(context: str, response: str, evidence: str = "") -> dict:
130
- input_text = f"Context: {context}\n\nEvidence: {evidence}\n\nResponse: {response}"
131
  inputs = tokenizer(
132
  input_text, return_tensors="pt",
133
  max_length=512, truncation=True, padding=True
@@ -142,7 +138,7 @@ def detect_violations(context: str, response: str, evidence: str = "") -> dict:
142
  ]
143
 
144
  with torch.no_grad():
145
- logits = model(**inputs) # Shape: [1, 4]
146
 
147
  probs, violations = {}, {}
148
  for i, (maxim, temp) in enumerate(zip(maxim_names, temp_values)):
@@ -156,16 +152,12 @@ def detect_violations(context: str, response: str, evidence: str = "") -> dict:
156
  "is_cooperative": not any(violations.values())
157
  }
158
 
159
- # ── Example ─────────────────────────────────────────────────────────────────
160
  result = detect_violations(
161
  context="What do you think about the latest developments in AI?",
162
- response="Yes.", # Too short β€” Quantity violation
163
  evidence="AI has seen rapid advancement in large language models during 2024-2025."
164
  )
165
  print(result)
166
- # {'violations': {'quantity': True, 'quality': False, 'relation': False, 'manner': False},
167
- # 'probabilities': {'quantity': 0.97, 'quality': 0.02, 'relation': 0.03, 'manner': 0.11},
168
- # 'is_cooperative': False}
169
  ```
170
 
171
  ---
 
14
  - nlp
15
  - pragmatics
16
  datasets:
17
+ - topical-chat
18
  metrics:
19
  - f1
20
  - precision
 
30
  name: Multi-Label Gricean Maxim Violation Detection
31
  dataset:
32
  name: Topical-Chat (GriceBench held-out split, N=1000)
33
+ type: topical-chat
34
  split: test
35
  metrics:
36
  - type: f1
 
60
  [![HuggingFace](https://img.shields.io/badge/πŸ€—-GriceBench-yellow)](https://huggingface.co/Pushkar27)
61
  [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
62
 
63
+ **Part of the GriceBench system** β€”
64
+ [GitHub](https://github.com/PushkarPrabhath27/Research-Model) |
65
+ [πŸ”§ Repair Model](https://huggingface.co/Pushkar27/GriceBench-Repair) |
66
  [⚑ DPO Generator](https://huggingface.co/Pushkar27/GriceBench-DPO)
67
 
68
  </div>
 
92
  import json
93
  from transformers import AutoTokenizer, AutoModel
94
 
 
95
  class MaximDetector(nn.Module):
96
  def __init__(self, model_name="microsoft/deberta-v3-base", num_maxims=4):
97
  super().__init__()
98
  self.encoder = AutoModel.from_pretrained(model_name)
99
+ hidden = self.encoder.config.hidden_size
100
  self.classifiers = nn.ModuleList([
101
  nn.Sequential(
102
  nn.Dropout(0.15),
 
113
  cls = outputs.last_hidden_state[:, 0, :]
114
  return torch.cat([head(cls) for head in self.classifiers], dim=1)
115
 
 
 
116
  tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
117
  model = MaximDetector()
118
  state_dict = torch.load("pytorch_model.pt", map_location="cpu")
 
122
  with open("temperatures.json") as f:
123
  temperatures = json.load(f)
124
 
 
125
  def detect_violations(context: str, response: str, evidence: str = "") -> dict:
126
+ input_text = f"Context: {context}\nEvidence: {evidence}\nResponse: {response}"
127
  inputs = tokenizer(
128
  input_text, return_tensors="pt",
129
  max_length=512, truncation=True, padding=True
 
138
  ]
139
 
140
  with torch.no_grad():
141
+ logits = model(**inputs)
142
 
143
  probs, violations = {}, {}
144
  for i, (maxim, temp) in enumerate(zip(maxim_names, temp_values)):
 
152
  "is_cooperative": not any(violations.values())
153
  }
154
 
 
155
  result = detect_violations(
156
  context="What do you think about the latest developments in AI?",
157
+ response="Yes.",
158
  evidence="AI has seen rapid advancement in large language models during 2024-2025."
159
  )
160
  print(result)
 
 
 
161
  ```
162
 
163
  ---