# Hugging Face file-page header captured along with the source, kept as a comment:
# PooryaPiroozfar — "Update app.py" (commit 02b3114, verified)
# -*- coding: utf-8 -*-
import os
import json
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import gradio as gr
from transformers import (
AutoTokenizer,
AutoModel,
AutoModelForTokenClassification
)
from huggingface_hub import snapshot_download
# -------------------------
# General configuration
# -------------------------
# Inference is forced onto CPU (no GPU assumed in the Space container).
device = torch.device("cpu")
# Hub repositories holding the fine-tuned ParsBERT checkpoints.
FRAME_DET_REPO = "PooryaPiroozfar/frame-detection-parsbert"
FE_REPO = "PooryaPiroozfar/srl-frame-elements-parsbert"
# Local directories the repositories are snapshotted into at startup.
FRAME_DET_DIR = "models/frame_detection"
FE_BASE_DIR = "models/frame_elements"
# Excel sheet mapping (Frame, Subject FE, Object FE) -> relation name.
TRIPLES_PATH = "frame_triples.xlsx"
# Minimum cosine similarity for a sentence to be assigned any frame at all;
# below this, predict_frame reports "out of domain".
THRESHOLD = 0.25
# Supported frames. Order matters: the index into this list is the row index
# of the trained frame-embedding matrix loaded below.
frame_names = [
    "Activity_finish","Activity_start","Aging","Attaching","Attempt",
    "Becoming","Being_born","Borrowing","Causation","Chatting",
    "Choosing","Closure","Clothing","Cutting","Damaging","Desiring","Discussion",
    "Emphasizing","Food","Installing","Locating","Memory","Morality_evaluation",
    "Motion","Offering","Practice","Project","Publishing","Religious_belief",
    "Removing","Request","Residence","Sharing","Taking","Telling","Travel",
    "Using","Visiting","Waiting","Work"
]
# -------------------------
# Model download
# -------------------------
# Fetch the fine-tuned checkpoints from the Hub once at startup; skipped on
# warm restarts where the local snapshot directories already exist.
if not os.path.exists(FRAME_DET_DIR):
    snapshot_download(repo_id=FRAME_DET_REPO, local_dir=FRAME_DET_DIR)
if not os.path.exists(FE_BASE_DIR):
    snapshot_download(repo_id=FE_REPO, local_dir=FE_BASE_DIR)
# -------------------------
# Sentence Encoder (ParsBERT)
# -------------------------
# Base (not fine-tuned) ParsBERT, used only to embed whole input sentences
# for the frame-similarity model.
encoder_name = "HooshvareLab/bert-base-parsbert-uncased"
sent_tokenizer = AutoTokenizer.from_pretrained(encoder_name)
sent_encoder = AutoModel.from_pretrained(encoder_name).to(device)
sent_encoder.eval()  # inference only; disables dropout
def get_embedding(text):
    """Return a mean-pooled ParsBERT embedding for *text*.

    Tokens are encoded (truncated/padded to 128), the last hidden states are
    averaged over the real (attention-masked) tokens, and the resulting
    768-dim vector is returned with the batch dimension squeezed away.
    """
    encoded = sent_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    ).to(device)
    with torch.no_grad():
        hidden = sent_encoder(**encoded).last_hidden_state
    # Masked mean pooling: padding positions contribute nothing.
    attn = encoded["attention_mask"].unsqueeze(-1).expand(hidden.size()).float()
    pooled = (hidden * attn).sum(dim=1) / torch.clamp(attn.sum(dim=1), min=1e-9)
    return pooled.squeeze(0)
# -------------------------
# Frame Detection Model
# -------------------------
class FrameSimilarityModel(nn.Module):
    """Scores a sentence embedding against every learned frame embedding.

    A single linear layer projects the sentence embedding into the frame
    space; cosine similarity (dot product of L2-normalized vectors) is then
    computed against each frame embedding row.
    """

    def __init__(self, emb_dim, frame_emb_init):
        super().__init__()
        self.proj = nn.Linear(emb_dim, emb_dim)
        # Learned per-frame embeddings, seeded from a precomputed matrix
        # (one row per frame, in frame_names order).
        self.frame_embeddings = nn.Parameter(
            torch.tensor(frame_emb_init, dtype=torch.float32)
        )

    def forward(self, sent_emb):
        """Return a (batch, n_frames) matrix of cosine similarities."""
        projected = F.normalize(self.proj(sent_emb), dim=-1)
        frame_dirs = F.normalize(self.frame_embeddings, dim=-1)
        return projected @ frame_dirs.T
# Trained frame-embedding matrix: one 768-dim row per entry in frame_names.
frame_embs = np.load(os.path.join(FRAME_DET_DIR, "trained_frame_embeddings.npy"))
frame_model = FrameSimilarityModel(
    emb_dim=768,  # ParsBERT hidden size
    frame_emb_init=frame_embs
).to(device)
# Restore the fine-tuned projection + frame embeddings (overwrites the init above).
state_dict = torch.load(
    os.path.join(FRAME_DET_DIR, "best_frame_margin_model.pt"),
    map_location="cpu"
)
frame_model.load_state_dict(state_dict)
frame_model.eval()  # inference only
def predict_frame(sentence):
    """Predict the semantic frame of *sentence*.

    Returns a (frame_name, similarity) pair; frame_name is None when the best
    cosine similarity falls below THRESHOLD (out-of-domain sentence).
    """
    sent_emb = get_embedding(sentence).unsqueeze(0)
    with torch.no_grad():
        scores = frame_model(sent_emb)
    best_score, best_idx = torch.max(scores, dim=1)
    score = best_score.item()
    if score < THRESHOLD:
        return None, score
    return frame_names[best_idx.item()], score
# -------------------------
# Frame Elements (SRL)
# -------------------------
# Cache of per-frame SRL resources, keyed by frame directory. The original
# code reloaded the tokenizer and model from disk on EVERY request, which is
# both slow and wasteful; each frame's model is now loaded at most once.
_FE_CACHE = {}


def _load_fe_resources(frame_dir):
    """Load (tokenizer, model, id2label) for one frame directory, caching the result."""
    if frame_dir not in _FE_CACHE:
        with open(os.path.join(frame_dir, "label2id.json"), encoding="utf-8") as f:
            label2id = json.load(f)
        id2label = {int(v): k for k, v in label2id.items()}
        tokenizer = AutoTokenizer.from_pretrained(frame_dir)
        model = AutoModelForTokenClassification.from_pretrained(
            frame_dir,
            num_labels=len(label2id),
            id2label=id2label,
            label2id=label2id
        ).to(device)
        model.eval()
        _FE_CACHE[frame_dir] = (tokenizer, model, id2label)
    return _FE_CACHE[frame_dir]


def predict_frame_elements(sentence, frame_name):
    """Tag *sentence* with the frame elements of *frame_name*.

    Returns a list of (token, label) pairs for every non-special tokenizer
    token whose predicted label is not "O". Returns [] when no per-frame
    model directory exists for *frame_name*.
    """
    frame_dir = os.path.join(FE_BASE_DIR, frame_name)
    if not os.path.exists(frame_dir):
        return []
    tokenizer, model, id2label = _load_fe_resources(frame_dir)
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        outputs = model(**inputs)
    preds = torch.argmax(outputs.logits, dim=-1).squeeze(0).tolist()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze(0))
    elements = []
    for tok, lab_id in zip(tokens, preds):
        if tok in {"[CLS]", "[SEP]", "[PAD]"}:
            continue
        label = id2label[lab_id]
        if label != "O":
            elements.append((tok, label))
    return elements
# -------------------------
# Triple Extraction
# -------------------------
# Columns used below: Frame, Subject, Object, Relation. Loaded once at startup.
triples_df = pd.read_excel(TRIPLES_PATH)
def group_elements(elements):
    """Group (token, label) pairs into {label: [tokens...]}, preserving order."""
    grouped = {}
    for token, label in elements:
        grouped.setdefault(label, []).append(token)
    return grouped
def extract_relations(frame_name, elements):
    """Turn tagged frame elements into subject-relation-object triples.

    For every triple template registered for *frame_name* in triples_df,
    emit one relation dict for each (subject token, object token) pair found
    among the tagged elements.
    """
    fe_dict = group_elements(elements)
    frame_rows = triples_df[triples_df["Frame"] == frame_name]
    relations = []
    for _, row in frame_rows.iterrows():
        subj_fe, obj_fe = row["Subject"], row["Object"]
        if subj_fe not in fe_dict or obj_fe not in fe_dict:
            continue
        relations.extend(
            {
                "subject": subj_tok,
                "relation": row["Relation"],
                "object": obj_tok,
                "subject_fe": subj_fe,
                "object_fe": obj_fe
            }
            for subj_tok in fe_dict[subj_fe]
            for obj_tok in fe_dict[obj_fe]
        )
    return relations
# -------------------------
# Sentence Utilities
# -------------------------
# Sentence terminators: period, '!', ASCII '?', Arabic question mark '؟',
# and the horizontal ellipsis. The original pattern escaped '؟' but omitted
# the plain ASCII '?', so mixed-script input was under-split; compiled once
# since it runs on every request.
_SENT_BOUNDARY_RE = re.compile(r'[.!?؟…]+')


def split_sentences(text):
    """Split *text* on runs of sentence-final punctuation, dropping empties."""
    return [s.strip() for s in _SENT_BOUNDARY_RE.split(text) if s.strip()]
# Surface markers that introduce a Persian conditional clause:
# "if", "in case", "provided that", "whenever".
CONDITIONAL_PATTERNS = [
    r'\bاگر\b',
    r'\bچنانچه\b',
    r'\bدر صورتی که\b',
    r'\bهرگاه\b'
]
# Compiled once at import time; searching a compiled pattern is equivalent
# to re.search(p, s) with the raw string.
_CONDITIONAL_RES = [re.compile(p) for p in CONDITIONAL_PATTERNS]


def is_conditional(sentence):
    """Return True when any conditional marker occurs in *sentence*."""
    return any(rx.search(sentence) for rx in _CONDITIONAL_RES)
def split_condition(sentence):
    """Split a conditional sentence at the first Persian comma '،'.

    Returns (condition, result), both stripped. When no Persian comma is
    present the whole sentence is returned unchanged with an empty result.
    """
    head, sep, tail = sentence.partition("،")
    if sep:
        return head.strip(), tail.strip()
    return sentence, ""
# -------------------------
# SPIN Rule Builder
# -------------------------
def build_spin_rule(if_triples, then_triples, rule_id):
    """Render an IF/THEN SPIN rule from antecedent and consequent triples.

    Each triple dict is rendered as "(subject relation object)"; multiple
    triples on a side are joined with " AND ". Returns None when either
    side is empty (no rule can be formed).
    """
    if not if_triples or not then_triples:
        return None

    def render(triple):
        return f"({triple['subject']} {triple['relation']} {triple['object']})"

    antecedent = " AND ".join(render(t) for t in if_triples)
    consequent = " AND ".join(render(t) for t in then_triples)
    return f"""
:Rule{rule_id} a spin:Rule ;
    spin:body [
        a sp:Ask ;
        sp:text \"\"\"
        IF {antecedent}
        THEN {consequent}
        \"\"\"
    ] .
""".strip()
# -------------------------
# Analyze One Sentence
# -------------------------
def analyze_sentence(sentence):
    """Run the full per-sentence pipeline: frame -> elements -> relations.

    Out-of-domain sentences (similarity below THRESHOLD) are reported with
    the placeholder frame name and empty elements/relations.
    """
    frame, similarity = predict_frame(sentence)
    # Start from the out-of-domain shape; fill in when a frame was found.
    result = {
        "frame": "خارج از دامنه",
        "similarity": round(similarity, 3),
        "elements": [],
        "relations": []
    }
    if frame is not None:
        elements = predict_frame_elements(sentence, frame)
        result["frame"] = frame
        result["elements"] = elements
        result["relations"] = extract_relations(frame, elements)
    return result
# -------------------------
# Main Pipeline
# -------------------------
def analyze(text):
    """Analyze free text: split into sentences, handle conditionals specially.

    Conditional sentences are split at the first Persian comma, each half is
    analyzed independently, and a SPIN rule is built when both halves yield
    relations. Simple sentences get the plain per-sentence analysis.
    NOTE: conditional entries use Persian keys while simple entries use
    English keys — preserved exactly as the original output schema.
    """
    results = []
    # Rule ids number conditional sentences only, in order of appearance
    # (incremented even when no rule could be built — matches original).
    next_rule_id = 1
    for sent in split_sentences(text):
        if not is_conditional(sent):
            results.append({
                "sentence": sent,
                "type": "simple",
                **analyze_sentence(sent)
            })
            continue
        cond_text, res_text = split_condition(sent)
        cond_res = analyze_sentence(cond_text) if cond_text else None
        res_res = analyze_sentence(res_text) if res_text else None
        spin_rule = None
        if cond_res and res_res:
            spin_rule = build_spin_rule(
                cond_res["relations"],
                res_res["relations"],
                next_rule_id
            )
        next_rule_id += 1
        results.append({
            "جمله": sent,
            "نوع_جمله": "شرطی",
            "دارای_قانون": spin_rule is not None,
            "شرط": cond_res,
            "نتیجه": res_res,
            "قانون_SPIN": spin_rule
        })
    return results
# -------------------------
# Gradio UI
# -------------------------
# Single-textbox -> JSON interface around the analyze() pipeline.
demo = gr.Interface(
    fn=analyze,
    inputs=gr.Textbox(
        label="متن فارسی",
        placeholder="مثال: اگر علی از تهران به مشهد برود، شغل خوبی انتخاب می کند"
    ),
    outputs=gr.JSON(label="خروجی"),
    title="Persian Semantic Frame & Rule Extractor",
    description="تشخیص فریم، عناصر معنایی، triple و قوانین SPIN"
)
if __name__ == "__main__":
    # 0.0.0.0:7860 is the standard binding for a Hugging Face Space container.
    demo.launch(server_name="0.0.0.0", server_port=7860)