Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*-
import json
import os
import re
from functools import lru_cache

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForTokenClassification,
)
# -------------------------
# Global configuration
# -------------------------
device = torch.device("cpu")  # CPU-only inference

FRAME_DET_REPO = "PooryaPiroozfar/frame-detection-parsbert"
FE_REPO = "PooryaPiroozfar/srl-frame-elements-parsbert"
FRAME_DET_DIR = "models/frame_detection"
FE_BASE_DIR = "models/frame_elements"
TRIPLES_PATH = "frame_triples.xlsx"

# Minimum cosine similarity for a sentence to be assigned a frame;
# anything below is reported as out-of-domain.
THRESHOLD = 0.25

# Frame inventory. The index order is used to label the similarity scores
# produced by the frame-detection model, so it must line up with the rows
# of the trained frame-embedding matrix loaded further down.
frame_names = [
    "Activity_finish", "Activity_start", "Aging", "Attaching", "Attempt",
    "Becoming", "Being_born", "Borrowing", "Causation", "Chatting",
    "Choosing", "Closure", "Clothing", "Cutting", "Damaging", "Desiring", "Discussion",
    "Emphasizing", "Food", "Installing", "Locating", "Memory", "Morality_evaluation",
    "Motion", "Offering", "Practice", "Project", "Publishing", "Religious_belief",
    "Removing", "Request", "Residence", "Sharing", "Taking", "Telling", "Travel",
    "Using", "Visiting", "Waiting", "Work",
]
# -------------------------
# Download model snapshots (skipped when already present on disk)
# -------------------------
for _repo_id, _local_dir in (
    (FRAME_DET_REPO, FRAME_DET_DIR),
    (FE_REPO, FE_BASE_DIR),
):
    if not os.path.exists(_local_dir):
        snapshot_download(repo_id=_repo_id, local_dir=_local_dir)
# -------------------------
# Sentence encoder (ParsBERT)
# -------------------------
encoder_name = "HooshvareLab/bert-base-parsbert-uncased"
sent_tokenizer = AutoTokenizer.from_pretrained(encoder_name)
# .eval() disables dropout — the encoder is used for inference only.
sent_encoder = AutoModel.from_pretrained(encoder_name).to(device).eval()
def get_embedding(text):
    """Return a mean-pooled ParsBERT embedding for *text* (1-D tensor).

    Pooling is the attention-mask-weighted mean over token embeddings,
    so padding positions do not contribute to the sentence vector.
    """
    encoded = sent_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    ).to(device)
    with torch.no_grad():
        hidden = sent_encoder(**encoded).last_hidden_state
    # (batch, seq, 1) mask broadcasts over the hidden dimension.
    mask = encoded["attention_mask"].unsqueeze(-1).float()
    total = (hidden * mask).sum(dim=1)
    # Clamp avoids division by zero for a (degenerate) all-padding input.
    count = mask.sum(dim=1).clamp(min=1e-9)
    return (total / count).squeeze(0)
# -------------------------
# Frame detection model
# -------------------------
class FrameSimilarityModel(nn.Module):
    """Scores sentence embeddings against learned frame embeddings.

    The sentence vector is passed through a learned linear projection;
    both the projection and the frame embeddings are L2-normalised, so
    the outputs are cosine similarities in [-1, 1].
    """

    def __init__(self, emb_dim, frame_emb_init):
        super().__init__()
        self.proj = nn.Linear(emb_dim, emb_dim)
        # Frame embeddings are trainable parameters initialised from a
        # pre-computed matrix (one row per frame).
        self.frame_embeddings = nn.Parameter(
            torch.tensor(frame_emb_init, dtype=torch.float32)
        )

    def forward(self, sent_emb):
        """Return cosine similarities of shape (..., n_frames)."""
        query = F.normalize(self.proj(sent_emb), dim=-1)
        keys = F.normalize(self.frame_embeddings, dim=-1)
        return query @ keys.T
# -------------------------
# Load the trained frame-detection weights
# -------------------------
# Row order of this matrix must correspond to `frame_names`.
frame_embs = np.load(os.path.join(FRAME_DET_DIR, "trained_frame_embeddings.npy"))

frame_model = FrameSimilarityModel(
    emb_dim=768,  # ParsBERT hidden size
    frame_emb_init=frame_embs,
).to(device)

# weights_only=True: the checkpoint is a plain state_dict of tensors, so
# restrict unpickling to tensor data instead of arbitrary Python objects
# (torch.load's legacy default runs a full, unsafe unpickler).
state_dict = torch.load(
    os.path.join(FRAME_DET_DIR, "best_frame_margin_model.pt"),
    map_location="cpu",
    weights_only=True,
)
frame_model.load_state_dict(state_dict)
frame_model.eval()  # inference only
def predict_frame(sentence):
    """Classify *sentence* into a frame.

    Returns (frame_name, similarity) for the best-matching frame, or
    (None, similarity) when the best cosine score is below THRESHOLD
    (treated as out-of-domain by the caller).
    """
    sent_emb = get_embedding(sentence).unsqueeze(0)
    with torch.no_grad():
        scores = frame_model(sent_emb)
    best_score, best_idx = torch.max(scores, dim=1)
    score = best_score.item()
    if score < THRESHOLD:
        return None, score
    return frame_names[best_idx.item()], score
# -------------------------
# Frame elements (SRL)
# -------------------------
@lru_cache(maxsize=None)
def _load_fe_model(frame_dir):
    """Load and cache (tokenizer, model, id2label) for one frame's SRL head.

    Cached because the original code re-read the tokenizer, the model
    weights, and the label map from disk on every single call, which is
    very expensive for repeated sentences of the same frame. The number
    of frames is small and fixed, so an unbounded cache is safe here.
    """
    with open(os.path.join(frame_dir, "label2id.json"), encoding="utf-8") as f:
        label2id = json.load(f)
    id2label = {int(v): k for k, v in label2id.items()}
    tokenizer = AutoTokenizer.from_pretrained(frame_dir)
    model = AutoModelForTokenClassification.from_pretrained(
        frame_dir,
        num_labels=len(label2id),
        id2label=id2label,
        label2id=label2id,
    ).to(device)
    model.eval()
    return tokenizer, model, id2label

def predict_frame_elements(sentence, frame_name):
    """Tag *sentence* with the frame elements of *frame_name*.

    Returns a list of (wordpiece_token, label) pairs, skipping special
    tokens and tokens labelled "O". Returns [] when no per-frame model
    directory exists for *frame_name*.
    """
    frame_dir = os.path.join(FE_BASE_DIR, frame_name)
    if not os.path.exists(frame_dir):
        return []
    tokenizer, model, id2label = _load_fe_model(frame_dir)
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        logits = model(**inputs).logits
    preds = torch.argmax(logits, dim=-1).squeeze(0).tolist()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze(0))
    elements = []
    for tok, lab_id in zip(tokens, preds):
        if tok in {"[CLS]", "[SEP]", "[PAD]"}:
            continue
        label = id2label[lab_id]
        if label != "O":
            elements.append((tok, label))
    return elements
# -------------------------
# Triple Extraction
# -------------------------
# Lookup table mapping each frame to (Subject FE, Relation, Object FE)
# patterns; extract_relations below reads the columns "Frame", "Subject",
# "Relation", and "Object".
triples_df = pd.read_excel(TRIPLES_PATH)
def group_elements(elements):
    """Group (token, label) pairs into {label: [token, ...]}.

    Token order within each label's list follows the input order.
    """
    grouped = {}
    for token, label in elements:
        bucket = grouped.get(label)
        if bucket is None:
            grouped[label] = [token]
        else:
            bucket.append(token)
    return grouped
def extract_relations(frame_name, elements):
    """Build subject-relation-object triples for *frame_name*.

    Looks up the frame's (Subject, Relation, Object) FE patterns in the
    module-level `triples_df` and emits one triple for every pairing of a
    subject token with an object token among the tagged elements.
    """
    fe_groups = group_elements(elements)
    frame_rows = triples_df[triples_df["Frame"] == frame_name]
    triples = []
    for _, row in frame_rows.iterrows():
        subj_fe, obj_fe = row["Subject"], row["Object"]
        # Skip patterns whose FEs were not found in this sentence.
        if subj_fe not in fe_groups or obj_fe not in fe_groups:
            continue
        for subj_tok in fe_groups[subj_fe]:
            for obj_tok in fe_groups[obj_fe]:
                triples.append({
                    "subject": subj_tok,
                    "relation": row["Relation"],
                    "object": obj_tok,
                    "subject_fe": subj_fe,
                    "object_fe": obj_fe,
                })
    return triples
# -------------------------
# Sentence utilities
# -------------------------
# Sentence-ending punctuation: period, exclamation mark, Arabic question
# mark, horizontal ellipsis. Compiled once; reused for every call.
_SENTENCE_DELIMS = re.compile(r'[\.!\؟…]+')

def split_sentences(text):
    """Split *text* into sentences on terminal punctuation.

    Empty fragments are dropped and surrounding whitespace is trimmed.
    """
    fragments = _SENTENCE_DELIMS.split(text)
    return [fragment.strip() for fragment in fragments if fragment.strip()]
# Persian conditional markers ("if", "in case", "whenever", ...).
CONDITIONAL_PATTERNS = [
    r'\bاگر\b',
    r'\bچنانچه\b',
    r'\bدر صورتی که\b',
    r'\bهرگاه\b'
]

def is_conditional(sentence):
    """Return True when *sentence* contains a Persian conditional marker."""
    for pattern in CONDITIONAL_PATTERNS:
        if re.search(pattern, sentence):
            return True
    return False
def split_condition(sentence):
    """Split a conditional sentence at the first Persian comma ("،").

    Returns (condition, result), both stripped; when no comma is present
    the whole sentence is the condition and the result is "".
    """
    head, comma, tail = sentence.partition("،")
    if comma:
        return head.strip(), tail.strip()
    return sentence, ""
# -------------------------
# SPIN rule builder
# -------------------------
def build_spin_rule(if_triples, then_triples, rule_id):
    """Serialise IF/THEN triples as a SPIN rule (Turtle-like text).

    Returns None when either side has no triples.
    NOTE(review): the source's indentation was stripped by extraction, so
    the exact whitespace inside the rule template is reconstructed here.
    """
    if not if_triples or not then_triples:
        return None

    def render(triple):
        return f"({triple['subject']} {triple['relation']} {triple['object']})"

    if_part = " AND ".join(render(t) for t in if_triples)
    then_part = " AND ".join(render(t) for t in then_triples)
    return f"""
:Rule{rule_id} a spin:Rule ;
    spin:body [
        a sp:Ask ;
        sp:text \"\"\"
IF {if_part}
THEN {then_part}
\"\"\"
    ] .
""".strip()
# -------------------------
# Single-sentence analysis
# -------------------------
def analyze_sentence(sentence):
    """Run the full frame pipeline on one sentence.

    Returns a dict with keys: frame, similarity, elements, relations.
    When the best similarity falls below THRESHOLD, the frame is reported
    as out-of-domain ("خارج از دامنه") with empty elements/relations.
    """
    frame, similarity = predict_frame(sentence)
    result = {
        "frame": "خارج از دامنه" if frame is None else frame,
        "similarity": round(similarity, 3),
        "elements": [],
        "relations": [],
    }
    if frame is not None:
        result["elements"] = predict_frame_elements(sentence, frame)
        result["relations"] = extract_relations(frame, result["elements"])
    return result
# -------------------------
# Main pipeline
# -------------------------
def analyze(text):
    """Analyse every sentence in *text* and return a list of result dicts.

    Conditional sentences are split into condition/result halves; when
    both halves yield triples, a SPIN rule is generated for the pair.
    NOTE(review): conditional results use Persian keys while simple
    results use English keys — preserved as-is for output compatibility.
    """
    results = []
    rule_id = 1
    for sent in split_sentences(text):
        if not is_conditional(sent):
            results.append({
                "sentence": sent,
                "type": "simple",
                **analyze_sentence(sent),
            })
            continue
        cond_text, res_text = split_condition(sent)
        cond_res = analyze_sentence(cond_text) if cond_text else None
        res_res = analyze_sentence(res_text) if res_text else None
        spin_rule = None
        if cond_res and res_res:
            spin_rule = build_spin_rule(
                cond_res["relations"],
                res_res["relations"],
                rule_id,
            )
        # NOTE(review): rule_id advances for every conditional sentence,
        # even when no rule is emitted (behaviour kept from the original).
        rule_id += 1
        results.append({
            "جمله": sent,
            "نوع_جمله": "شرطی",
            "دارای_قانون": spin_rule is not None,
            "شرط": cond_res,
            "نتیجه": res_res,
            "قانون_SPIN": spin_rule,
        })
    return results
# -------------------------
# Gradio UI
# -------------------------
_input_box = gr.Textbox(
    label="متن فارسی",
    placeholder="مثال: اگر علی از تهران به مشهد برود، شغل خوبی انتخاب می کند"
)
_output_view = gr.JSON(label="خروجی")

demo = gr.Interface(
    fn=analyze,
    inputs=_input_box,
    outputs=_output_view,
    title="Persian Semantic Frame & Rule Extractor",
    description="تشخیص فریم، عناصر معنایی، triple و قوانین SPIN",
)

if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)