# -*- coding: utf-8 -*-
"""Persian semantic frame & rule extractor (Gradio app).

Pipeline for each input text:
1. Split the text into sentences.
2. Detect a semantic frame per sentence via cosine similarity between a
   projected ParsBERT sentence embedding and trained frame embeddings.
3. Tag frame elements (SRL) with a per-frame token-classification model.
4. Turn tagged elements into (subject, relation, object) triples using the
   frame/relation table in ``frame_triples.xlsx``.
5. For conditional sentences, combine the triples of the condition clause and
   the result clause into a SPIN rule string.
"""
import os
import json
import re
from functools import lru_cache

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import gradio as gr
from transformers import (
    AutoTokenizer, AutoModel, AutoModelForTokenClassification
)
from huggingface_hub import snapshot_download

# -------------------------
# General settings
# -------------------------
device = torch.device("cpu")

FRAME_DET_REPO = "PooryaPiroozfar/frame-detection-parsbert"
FE_REPO = "PooryaPiroozfar/srl-frame-elements-parsbert"

FRAME_DET_DIR = "models/frame_detection"
FE_BASE_DIR = "models/frame_elements"

TRIPLES_PATH = "frame_triples.xlsx"

# Minimum cosine similarity for a sentence to be assigned any frame.
THRESHOLD = 0.25

# Number of per-frame SRL models kept in memory at once. Each one is a full
# BERT-sized model, so the cache is bounded rather than unlimited.
FE_MODEL_CACHE_SIZE = 4

# Index i names column i of the similarity matrix returned by
# FrameSimilarityModel, i.e. row i of trained_frame_embeddings.npy.
frame_names = [
    "Activity_finish", "Activity_start", "Aging", "Attaching", "Attempt",
    "Becoming", "Being_born", "Borrowing", "Causation", "Chatting",
    "Choosing", "Closure", "Clothing", "Cutting", "Damaging", "Desiring", "Discussion",
    "Emphasizing", "Food", "Installing", "Locating", "Memory", "Morality_evaluation",
    "Motion", "Offering", "Practice", "Project", "Publishing", "Religious_belief",
    "Removing", "Request", "Residence", "Sharing", "Taking", "Telling", "Travel",
    "Using", "Visiting", "Waiting", "Work",
]

# -------------------------
# Model download
# -------------------------
if not os.path.exists(FRAME_DET_DIR):
    snapshot_download(repo_id=FRAME_DET_REPO, local_dir=FRAME_DET_DIR)

if not os.path.exists(FE_BASE_DIR):
    snapshot_download(repo_id=FE_REPO, local_dir=FE_BASE_DIR)

# -------------------------
# Sentence Encoder (ParsBERT)
# -------------------------
encoder_name = "HooshvareLab/bert-base-parsbert-uncased"
sent_tokenizer = AutoTokenizer.from_pretrained(encoder_name)
sent_encoder = AutoModel.from_pretrained(encoder_name).to(device)
sent_encoder.eval()


def get_embedding(text):
    """Return the mean-pooled ParsBERT embedding of *text*.

    Padding positions are excluded from the mean via the attention mask.
    The batch dimension is squeezed, so a single string yields a 1-D tensor.
    """
    inputs = sent_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    ).to(device)

    with torch.no_grad():
        outputs = sent_encoder(**inputs)

    token_embeddings = outputs.last_hidden_state
    mask = inputs["attention_mask"].unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, dim=1)
    # Clamp guards against division by zero for an all-padding row.
    lengths = torch.clamp(mask.sum(dim=1), min=1e-9)
    return (summed / lengths).squeeze(0)


# -------------------------
# Frame Detection Model
# -------------------------
class FrameSimilarityModel(nn.Module):
    """Cosine similarity between a projected sentence embedding and frame embeddings."""

    def __init__(self, emb_dim, frame_emb_init):
        super().__init__()
        self.proj = nn.Linear(emb_dim, emb_dim)
        self.frame_embeddings = nn.Parameter(
            torch.tensor(frame_emb_init, dtype=torch.float32)
        )

    def forward(self, sent_emb):
        # Both sides are L2-normalized, so the matmul yields cosine similarities
        # with one column per frame embedding row.
        sent_proj = F.normalize(self.proj(sent_emb), dim=-1)
        frames = F.normalize(self.frame_embeddings, dim=-1)
        return torch.matmul(sent_proj, frames.T)


frame_embs = np.load(os.path.join(FRAME_DET_DIR, "trained_frame_embeddings.npy"))

frame_model = FrameSimilarityModel(
    emb_dim=768,
    frame_emb_init=frame_embs,
).to(device)

# The checkpoint is loaded straight into load_state_dict, so it is a plain
# state_dict: weights_only=True restricts unpickling to tensors instead of
# executing arbitrary pickled objects from the downloaded repository.
state_dict = torch.load(
    os.path.join(FRAME_DET_DIR, "best_frame_margin_model.pt"),
    map_location="cpu",
    weights_only=True,
)
frame_model.load_state_dict(state_dict)
frame_model.eval()


def predict_frame(sentence):
    """Return ``(frame_name, similarity)`` for the best-matching frame.

    ``frame_name`` is None when the best similarity is below THRESHOLD.
    """
    emb = get_embedding(sentence).unsqueeze(0)
    with torch.no_grad():
        sims = frame_model(emb)
        max_sim, idx = torch.max(sims, dim=1)

    if max_sim.item() < THRESHOLD:
        return None, max_sim.item()

    return frame_names[idx.item()], max_sim.item()


# -------------------------
# Frame Elements (SRL)
# -------------------------
@lru_cache(maxsize=FE_MODEL_CACHE_SIZE)
def _load_fe_model(frame_dir):
    """Load and cache ``(tokenizer, model, id2label)`` for one frame's SRL model."""
    with open(os.path.join(frame_dir, "label2id.json"), encoding="utf-8") as f:
        label2id = {k: int(v) for k, v in json.load(f).items()}
    id2label = {v: k for k, v in label2id.items()}

    tokenizer = AutoTokenizer.from_pretrained(frame_dir)
    model = AutoModelForTokenClassification.from_pretrained(
        frame_dir,
        num_labels=len(label2id),
        id2label=id2label,
        label2id=label2id,
    ).to(device)
    model.eval()
    return tokenizer, model, id2label


def _merge_wordpieces(tokens, labels, skip_tokens):
    """Join WordPiece continuations ("##...") onto their word; keep the first piece's label."""
    words = []  # list of [word_text, label]
    for tok, label in zip(tokens, labels):
        if tok in skip_tokens:
            continue
        if tok.startswith("##") and words:
            words[-1][0] += tok[2:]
        else:
            words.append([tok, label])
    return [(text, label) for text, label in words]


def predict_frame_elements(sentence, frame_name):
    """Tag frame elements in *sentence* with the SRL model trained for *frame_name*.

    Returns a list of ``(word, label)`` pairs for words whose predicted label is
    not "O". WordPiece pieces are merged into whole words; each word takes the
    label predicted for its first piece. Returns [] when no SRL model directory
    exists for the frame.
    """
    frame_dir = os.path.join(FE_BASE_DIR, frame_name)
    if not os.path.exists(frame_dir):
        return []

    tokenizer, model, id2label = _load_fe_model(frame_dir)

    inputs = tokenizer(
        sentence, return_tensors="pt", truncation=True, max_length=128
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    pred_ids = torch.argmax(outputs.logits, dim=-1).squeeze(0).cpu().tolist()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze(0).tolist())
    labels = [id2label[i] for i in pred_ids]

    skip_tokens = {tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token}
    words = _merge_wordpieces(tokens, labels, skip_tokens)
    return [(word, label) for word, label in words if label != "O"]


# -------------------------
# Triple Extraction
# -------------------------
triples_df = pd.read_excel(TRIPLES_PATH)


def group_elements(elements):
    """Group ``(text, label)`` pairs into ``{label: [text, ...]}`` preserving order."""
    d = {}
    for tok, lab in elements:
        d.setdefault(lab, []).append(tok)
    return d


def extract_relations(frame_name, elements):
    """Build triples for *frame_name* from tagged *elements*.

    For every row of the triples table whose "Frame" equals *frame_name* and
    whose "Subject" and "Object" frame-element labels were both found, emit
    one relation per (subject text, object text) combination.
    """
    fe_dict = group_elements(elements)
    rows = triples_df[triples_df["Frame"] == frame_name]

    relations = []
    for _, r in rows.iterrows():
        if r["Subject"] in fe_dict and r["Object"] in fe_dict:
            for s in fe_dict[r["Subject"]]:
                for o in fe_dict[r["Object"]]:
                    relations.append({
                        "subject": s,
                        "relation": r["Relation"],
                        "object": o,
                        "subject_fe": r["Subject"],
                        "object_fe": r["Object"],
                    })
    return relations


# -------------------------
# Sentence Utilities
# -------------------------
# Sentence terminators: ".", "!", ASCII "?", Persian "؟", and ellipsis "…".
_SENTENCE_DELIMITERS = re.compile(r"[.!?؟…]+")


def split_sentences(text):
    """Split *text* into stripped, non-empty sentences ([] for empty/None input)."""
    if not text:
        return []
    sentences = _SENTENCE_DELIMITERS.split(text)
    return [s.strip() for s in sentences if s.strip()]


CONDITIONAL_PATTERNS = [
    r'\bاگر\b',
    r'\bچنانچه\b',
    r'\bدر صورتی که\b',
    r'\bهرگاه\b'
]


def is_conditional(sentence):
    """Return True if *sentence* contains one of the Persian conditional markers."""
    return any(re.search(p, sentence) for p in CONDITIONAL_PATTERNS)


def split_condition(sentence):
    """Split a conditional sentence at its first comma (Persian "،" or ASCII ",").

    Returns ``(condition, result)``; when there is no comma the whole sentence
    is returned as the condition and the result is "".
    """
    parts = re.split(r"[،,]", sentence, maxsplit=1)
    if len(parts) == 2:
        return parts[0].strip(), parts[1].strip()
    return sentence, ""


# -------------------------
# SPIN Rule Builder
# -------------------------
def build_spin_rule(if_triples, then_triples, rule_id):
    """Render a SPIN rule string, or None if either side has no triples."""
    if not if_triples or not then_triples:
        return None

    def t2s(t):
        return f"({t['subject']} {t['relation']} {t['object']})"

    if_part = " AND ".join(t2s(t) for t in if_triples)
    then_part = " AND ".join(t2s(t) for t in then_triples)

    return f"""
:Rule{rule_id} a spin:Rule ;
  spin:body [
    a sp:Ask ;
    sp:text \"\"\"
IF {if_part}
THEN {then_part}
\"\"\"
  ] .
""".strip()


# -------------------------
# Analyze One Sentence
# -------------------------
def analyze_sentence(sentence):
    """Run frame detection, SRL and triple extraction on one sentence."""
    frame, sim = predict_frame(sentence)

    if frame is None:
        return {
            "frame": "خارج از دامنه",
            "similarity": round(sim, 3),
            "elements": [],
            "relations": []
        }

    elements = predict_frame_elements(sentence, frame)
    relations = extract_relations(frame, elements)

    return {
        "frame": frame,
        "similarity": round(sim, 3),
        "elements": elements,
        "relations": relations
    }


# -------------------------
# Main Pipeline
# -------------------------
def analyze(text):
    """Analyze every sentence of *text*; conditional sentences may yield SPIN rules.

    NOTE(review): conditional entries use Persian keys while simple entries use
    English keys; both are kept as-is to preserve the JSON output format.
    """
    sentences = split_sentences(text)
    results = []
    rule_id = 1

    for sent in sentences:
        if is_conditional(sent):
            cond_text, res_text = split_condition(sent)

            cond_res = analyze_sentence(cond_text) if cond_text else None
            res_res = analyze_sentence(res_text) if res_text else None

            spin_rule = build_spin_rule(
                cond_res["relations"],
                res_res["relations"],
                rule_id
            ) if cond_res and res_res else None

            # Only consume an ID when a rule was actually produced, so rule
            # numbering stays contiguous.
            if spin_rule is not None:
                rule_id += 1

            results.append({
                "جمله": sent,
                "نوع_جمله": "شرطی",
                "دارای_قانون": spin_rule is not None,
                "شرط": cond_res,
                "نتیجه": res_res,
                "قانون_SPIN": spin_rule
            })
        else:
            simple_res = analyze_sentence(sent)
            results.append({
                "sentence": sent,
                "type": "simple",
                **simple_res
            })

    return results


# -------------------------
# Gradio UI
# -------------------------
demo = gr.Interface(
    fn=analyze,
    inputs=gr.Textbox(
        label="متن فارسی",
        placeholder="مثال: اگر علی از تهران به مشهد برود، شغل خوبی انتخاب می کند"
    ),
    outputs=gr.JSON(label="خروجی"),
    title="Persian Semantic Frame & Rule Extractor",
    description="تشخیص فریم، عناصر معنایی، triple و قوانین SPIN"
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)