PooryaPiroozfar commited on
Commit
e8ca7f8
·
verified ·
1 Parent(s): 51f72ca

Upload 4 files

Browse files
Files changed (4) hide show
  1. Dockerfile +21 -0
  2. app.py +317 -0
  3. frame_triples.xlsx +0 -0
  4. requirements.txt +8 -0
Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

# No .pyc files in the image; unbuffered stdout so container logs stream live.
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1

WORKDIR /app

# git: presumably for VCS-based pip/Hub downloads — TODO confirm it is needed.
# build-essential: compiles any wheels without prebuilt binaries.
# Removing the apt lists keeps the layer small.
RUN apt-get update && apt-get install -y \
    git \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements alone first so the pip layer is cached across code edits.
COPY requirements.txt .
RUN pip install --upgrade pip \
    && pip install --no-cache-dir -r requirements.txt

COPY . .

# Gradio port; app.py launches on 0.0.0.0:7860 to match.
EXPOSE 7860

CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-
"""Persian semantic frame detection, frame-element tagging, triple and
SPIN-rule extraction, served through a Gradio UI.

Importing this module has side effects: it downloads model snapshots from
the Hugging Face Hub (if absent) and loads the ParsBERT sentence encoder.
"""

import os
import json
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import gradio as gr

from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForTokenClassification
)
from huggingface_hub import snapshot_download

# -------------------------
# Global configuration
# -------------------------
# CPU-only inference; the Docker image makes no GPU assumption.
device = torch.device("cpu")

# Hub repos for the fine-tuned frame-detection model and the per-frame
# frame-element (SRL) models.
FRAME_DET_REPO = "PooryaPiroozfar/frame-detection-parsbert"
FE_REPO = "PooryaPiroozfar/srl-frame-elements-parsbert"

# Local directories the Hub snapshots are downloaded into.
FRAME_DET_DIR = "models/frame_detection"
FE_BASE_DIR = "models/frame_elements"

# Excel sheet of per-frame (Subject, Relation, Object) triple templates.
TRIPLES_PATH = "frame_triples.xlsx"
# Minimum cosine similarity for a sentence to be assigned any frame.
THRESHOLD = 0.2

# Supported frames. NOTE(review): order presumably matches the row order of
# trained_frame_embeddings.npy loaded below — confirm against training code.
frame_names = [
    "Activity_finish","Activity_start","Aging","Attaching","Attempt",
    "Becoming","Being_born","Borrowing","Causation","Chatting",
    "Choosing","Closure","Clothing","Cutting","Damaging","Desiring","Discussion",
    "Emphasizing","Food","Installing","Locating","Memory","Morality_evaluation",
    "Motion","Offering","Practice","Project","Publishing","Religious_belief",
    "Removing","Request","Residence","Sharing","Taking","Telling","Travel",
    "Using","Visiting","Waiting","Work"
]

# -------------------------
# Model downloads (import-time; skipped when the directories already exist)
# -------------------------
if not os.path.exists(FRAME_DET_DIR):
    snapshot_download(repo_id=FRAME_DET_REPO, local_dir=FRAME_DET_DIR)

if not os.path.exists(FE_BASE_DIR):
    snapshot_download(repo_id=FE_REPO, local_dir=FE_BASE_DIR)

# -------------------------
# Sentence Encoder (ParsBERT)
# -------------------------
encoder_name = "HooshvareLab/bert-base-parsbert-uncased"
sent_tokenizer = AutoTokenizer.from_pretrained(encoder_name)
sent_encoder = AutoModel.from_pretrained(encoder_name).to(device)
sent_encoder.eval()
60
+
61
def get_embedding(text):
    """Return a mean-pooled ParsBERT embedding of *text* as a 1-D tensor.

    Pooling is masked: padding positions contribute nothing to the mean.
    """
    encoded = sent_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    ).to(device)

    with torch.no_grad():
        hidden = sent_encoder(**encoded).last_hidden_state

    # Attention-mask-weighted mean over the token axis.
    weights = encoded["attention_mask"].unsqueeze(-1).float()
    total = (hidden * weights).sum(dim=1)
    count = weights.expand_as(hidden).sum(dim=1).clamp(min=1e-9)

    return (total / count).squeeze(0)
79
+
80
# -------------------------
# Frame Detection Model
# -------------------------
class FrameSimilarityModel(nn.Module):
    """Score a sentence embedding against learnable frame embeddings.

    The sentence vector is passed through a learned linear projection and
    compared to every frame embedding by cosine similarity (dot product of
    L2-normalized vectors). Output shape: (batch, num_frames).
    """

    def __init__(self, emb_dim, frame_emb_init):
        super().__init__()
        self.proj = nn.Linear(emb_dim, emb_dim)
        # Frame embeddings are trainable parameters, seeded from the
        # precomputed matrix passed in by the caller.
        self.frame_embeddings = nn.Parameter(
            torch.tensor(frame_emb_init, dtype=torch.float32)
        )

    def forward(self, sent_emb):
        projected = F.normalize(self.proj(sent_emb), dim=-1)
        frame_dirs = F.normalize(self.frame_embeddings, dim=-1)
        return projected @ frame_dirs.T
95
+
96
# Precomputed frame embedding matrix; rows align with frame_names (presumably
# — verify against the training pipeline).
frame_embs = np.load(os.path.join(FRAME_DET_DIR, "trained_frame_embeddings.npy"))

frame_model = FrameSimilarityModel(
    emb_dim=768,  # ParsBERT hidden size
    frame_emb_init=frame_embs
).to(device)

# NOTE(review): torch.load unpickles the checkpoint — acceptable only because
# it comes from our own Hub repo; never point this at untrusted files.
state_dict = torch.load(
    os.path.join(FRAME_DET_DIR, "best_frame_margin_model.pt"),
    map_location="cpu"
)
frame_model.load_state_dict(state_dict)
frame_model.eval()
109
+
110
def predict_frame(sentence):
    """Return (frame_name, similarity) for the best-matching frame.

    Returns (None, similarity) when the top score falls below THRESHOLD,
    i.e. the sentence is treated as out of domain.
    """
    embedding = get_embedding(sentence).unsqueeze(0)
    with torch.no_grad():
        similarities = frame_model(embedding)
    best_score, best_idx = torch.max(similarities, dim=1)

    score = best_score.item()
    if score < THRESHOLD:
        return None, score
    return frame_names[best_idx.item()], score
120
+
121
# -------------------------
# Frame Elements (SRL)
# -------------------------
def predict_frame_elements(sentence, frame_name):
    """Tag *sentence* with the frame-element model fine-tuned for *frame_name*.

    Returns a list of (token, label) pairs for every non-"O" prediction on a
    non-special token (tokens are WordPiece subwords, not whole words).
    Returns [] when no model directory exists for the frame.
    """
    frame_dir = os.path.join(FE_BASE_DIR, frame_name)
    if not os.path.exists(frame_dir):
        return []

    # Fix: the original reloaded tokenizer + model from disk on every call,
    # which dominated per-request latency. Cache them per frame; the weights
    # never change at runtime. (Function attribute avoids new globals.)
    cache = predict_frame_elements.__dict__.setdefault("_cache", {})
    if frame_name not in cache:
        with open(os.path.join(frame_dir, "label2id.json"), encoding="utf-8") as f:
            label2id = json.load(f)
        id2label = {int(v): k for k, v in label2id.items()}

        tokenizer = AutoTokenizer.from_pretrained(frame_dir)
        model = AutoModelForTokenClassification.from_pretrained(
            frame_dir,
            num_labels=len(label2id),
            id2label=id2label,
            label2id=label2id
        ).to(device)
        model.eval()
        cache[frame_name] = (tokenizer, model, id2label)

    tokenizer, model, id2label = cache[frame_name]

    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, max_length=128)

    with torch.no_grad():
        outputs = model(**inputs)

    preds = torch.argmax(outputs.logits, dim=-1).squeeze(0).tolist()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze(0))

    # Drop special tokens and the "outside" class; keep labeled subwords.
    special = {"[CLS]", "[SEP]", "[PAD]"}
    return [
        (tok, id2label[lab_id])
        for tok, lab_id in zip(tokens, preds)
        if tok not in special and id2label[lab_id] != "O"
    ]
160
+
161
# -------------------------
# Triple Extraction
# -------------------------
# Per-frame triple templates; expected columns (per usage in this file):
# Frame, Subject, Relation, Object. Requires openpyxl to read .xlsx.
triples_df = pd.read_excel(TRIPLES_PATH)
165
+
166
def group_elements(elements):
    """Group (token, label) pairs into {label: [tokens in original order]}."""
    grouped = {}
    for token, label in elements:
        if label not in grouped:
            grouped[label] = []
        grouped[label].append(token)
    return grouped
171
+
172
def extract_relations(frame_name, elements):
    """Instantiate the frame's triple templates with predicted elements.

    For every (Subject FE, Relation, Object FE) template of *frame_name* in
    triples_df, emit one dict per (subject token, object token) combination
    found among *elements*.
    """
    grouped = group_elements(elements)
    templates = triples_df[triples_df["Frame"] == frame_name]

    relations = []
    for _, tpl in templates.iterrows():
        subj_fe, obj_fe = tpl["Subject"], tpl["Object"]
        if subj_fe not in grouped or obj_fe not in grouped:
            continue
        for subj_tok in grouped[subj_fe]:
            for obj_tok in grouped[obj_fe]:
                relations.append({
                    "subject": subj_tok,
                    "relation": tpl["Relation"],
                    "object": obj_tok,
                    "subject_fe": subj_fe,
                    "object_fe": obj_fe
                })
    return relations
189
+
190
# -------------------------
# Sentence Utilities
# -------------------------
# Sentence terminators: period, Latin ! and ?, Persian ؟, ellipsis.
# Fix: the original class omitted the ASCII question mark, so mixed-script
# text like "چرا?" was never split. Compiled once instead of per call.
_SENTENCE_END_RE = re.compile(r'[.!?؟…]+')

def split_sentences(text):
    """Split *text* on sentence terminators, dropping empty fragments.

    Returns a list of stripped, non-empty sentence strings.
    """
    return [s.strip() for s in _SENTENCE_END_RE.split(text) if s.strip()]
196
+
197
# Markers that introduce a Persian conditional clause ("if", "in case",
# "provided that", "whenever").
CONDITIONAL_PATTERNS = [
    r'\bاگر\b',
    r'\bچنانچه\b',
    r'\bدر صورتی که\b',
    r'\bهرگاه\b'
]

def is_conditional(sentence):
    """Return True when the sentence contains any conditional marker."""
    for pattern in CONDITIONAL_PATTERNS:
        if re.search(pattern, sentence):
            return True
    return False

def split_condition(sentence):
    """Split a conditional sentence at the first Persian comma.

    Returns (condition, result), both stripped; result is "" when the
    sentence has no Persian comma.
    """
    condition, comma, result = sentence.partition("،")
    if comma:
        return condition.strip(), result.strip()
    return sentence, ""
212
+
213
# -------------------------
# SPIN Rule Builder
# -------------------------
def build_spin_rule(if_triples, then_triples, rule_id):
    """Render an IF/THEN pair of relation-dict lists as a SPIN rule string.

    Each triple dict must carry "subject", "relation" and "object" keys.
    Returns None when either side is empty (no rule can be formed).
    """
    if not if_triples or not then_triples:
        return None

    def fmt(triple):
        return f"({triple['subject']} {triple['relation']} {triple['object']})"

    if_part = " AND ".join(map(fmt, if_triples))
    then_part = " AND ".join(map(fmt, then_triples))

    return f"""
:Rule{rule_id} a spin:Rule ;
    spin:body [
        a sp:Ask ;
        sp:text \"\"\"
            IF {if_part}
            THEN {then_part}
        \"\"\"
    ] .
""".strip()
236
+
237
# -------------------------
# Analyze One Sentence
# -------------------------
def analyze_sentence(sentence):
    """Run the full single-sentence pipeline: frame → elements → relations.

    Returns a dict with keys "frame", "similarity", "elements", "relations".
    Out-of-domain sentences (below THRESHOLD) get the Persian label
    "خارج از دامنه" and empty element/relation lists.
    """
    frame, similarity = predict_frame(sentence)

    result = {
        "frame": frame if frame is not None else "خارج از دامنه",
        "similarity": round(similarity, 3),
        "elements": [],
        "relations": []
    }

    if frame is not None:
        result["elements"] = predict_frame_elements(sentence, frame)
        result["relations"] = extract_relations(frame, result["elements"])

    return result
260
+
261
# -------------------------
# Main Pipeline
# -------------------------
def analyze(text):
    """Analyze free Persian text and return one result dict per sentence.

    Simple sentences get {"sentence", "type": "simple", ...analysis...}.
    Conditional sentences are split into condition/result halves, each
    analyzed separately, and — when both halves yield relations — rendered
    into a SPIN rule.
    """
    results = []
    rule_id = 1

    for sentence in split_sentences(text):
        if not is_conditional(sentence):
            results.append({
                "sentence": sentence,
                "type": "simple",
                **analyze_sentence(sentence)
            })
            continue

        cond_text, res_text = split_condition(sentence)
        cond_res = analyze_sentence(cond_text) if cond_text else None
        res_res = analyze_sentence(res_text) if res_text else None

        spin_rule = None
        if cond_res and res_res:
            spin_rule = build_spin_rule(
                cond_res["relations"],
                res_res["relations"],
                rule_id
            )
        # Ids advance per conditional sentence, whether or not a rule formed
        # (matches the original numbering behavior).
        rule_id += 1

        results.append({
            "sentence": sentence,
            "type": "conditional",
            "has_rule": spin_rule is not None,
            "condition": cond_res,
            "result": res_res,
            "spin_rule": spin_rule
        })

    return results
301
+
302
# -------------------------
# Gradio UI
# -------------------------
# Single Persian-text textbox in, raw JSON analysis out. Labels/placeholder
# are intentionally Persian (user-facing strings; do not translate).
demo = gr.Interface(
    fn=analyze,
    inputs=gr.Textbox(
        label="متن فارسی",
        placeholder="مثال: اگر علی به تهران برود، خوشحال می‌شود. او دوستانش را ملاقات کرد."
    ),
    outputs=gr.JSON(label="خروجی"),
    title="Persian Semantic Frame & Rule Extractor",
    description="تشخیص فریم، عناصر معنایی، triple و قوانین SPIN"
)

if __name__ == "__main__":
    # Bind all interfaces on 7860 so the containerized app (EXPOSE 7860)
    # is reachable from outside Docker.
    demo.launch(server_name="0.0.0.0", server_port=7860)
frame_triples.xlsx ADDED
Binary file (25.6 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ sentencepiece
4
+ pandas
5
+ numpy
6
+ openpyxl
7
+ gradio
8
+ huggingface_hub