# ProjectPipline / app.py
# PooryaPiroozfar - "Upload app.py" (commit 660579d, verified)
# -*- coding: utf-8 -*-
"""SRL_Pipline_Docker.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1FoWa87UBXFtiFB26Du-XNnrWckLJvmkD
"""
import functools
import json
import os

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from transformers import (
    AutoModel,
    AutoModelForTokenClassification,
    AutoTokenizer,
)
# -------------------------
# General settings
# -------------------------
# All models in this file are placed on this device.
device = torch.device("cpu")

# Hugging Face Hub repositories holding the trained models.
FRAME_DET_REPO = "PooryaPiroozfar/frame-detection-parsbert"
FE_REPO = "PooryaPiroozfar/srl-frame-elements-parsbert"

# Local directories the repositories are downloaded into.
FRAME_DET_DIR = "models/frame_detection"
FE_BASE_DIR = "models/frame_elements"

# Spreadsheet with the frame -> triple templates; must be present in the repo.
TRIPLES_PATH = "frame_triples.xlsx"

# predict_frame() reports "no frame" when the best similarity is below this.
THRESHOLD = 0.2

# Frame labels; predict_frame() indexes this list with the argmax of the
# frame-similarity scores, so the order is significant.
# NOTE(review): presumably the order matches the rows of the trained frame
# embeddings - confirm against the training code.
frame_names = [
    "Activity_finish",
    "Activity_start",
    "Aging",
    "Attaching",
    "Attempt",
    "Becoming",
    "Being_born",
    "Borrowing",
    "Causation",
    "Chatting",
    "Choosing",
    "Closure",
    "Clothing",
    "Cutting",
    "Damaging",
    "Desiring",
    "Discussion",
    "Emphasizing",
    "Food",
    "Installing",
    "Locating",
    "Memory",
    "Morality_evaluation",
    "Motion",
    "Offering",
    "Practice",
    "Project",
    "Publishing",
    "Religious_belief",
    "Removing",
    "Request",
    "Residence",
    "Sharing",
    "Taking",
    "Telling",
    "Travel",
    "Using",
    "Visiting",
    "Waiting",
    "Work",
]
# -------------------------
# Download the models (once)
# -------------------------
# A repository is fetched only when its local directory does not exist yet;
# an existing directory is used as-is without checking its contents.
for _repo, _target in (
    (FRAME_DET_REPO, FRAME_DET_DIR),
    (FE_REPO, FE_BASE_DIR),
):
    if not os.path.exists(_target):
        snapshot_download(repo_id=_repo, local_dir=_target)
# -------------------------
# Encoder (ParsBERT)
# -------------------------
# Tokenizer and encoder used by get_embedding() to embed whole sentences.
encoder_name = "HooshvareLab/bert-base-parsbert-uncased"
sent_tokenizer = AutoTokenizer.from_pretrained(encoder_name)
# .to() and .eval() both return the module, so the calls can be chained;
# eval() switches the encoder to inference mode.
sent_encoder = AutoModel.from_pretrained(encoder_name).to(device).eval()
def get_embedding(text):
    """Embed ``text`` by mean-pooling the encoder's last hidden states.

    The text is tokenized (truncated to 128 tokens), run through
    ``sent_encoder`` without gradient tracking, and the token vectors are
    averaged using the attention mask as weights so masked-out positions
    do not contribute.

    Returns:
        The pooled embedding tensor, with the leading dimension removed
        when it has size 1.
    """
    encoded = sent_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    ).to(device)

    with torch.no_grad():
        hidden = sent_encoder(**encoded).last_hidden_state

    # Stretch the attention mask across the hidden dimension so it can be
    # used as a per-element weight.
    weights = encoded["attention_mask"].unsqueeze(-1).expand_as(hidden).float()
    weighted_sum = (hidden * weights).sum(dim=1)
    # The clamp guards the division against an all-zero mask.
    token_counts = weights.sum(dim=1).clamp(min=1e-9)
    pooled = weighted_sum / token_counts
    return pooled.squeeze(0)
# -------------------------
# Frame detection model
# -------------------------
class FrameSimilarityModel(nn.Module):
    """Cosine similarity between projected sentence embeddings and frames.

    A linear layer maps the sentence embedding into the frame space; both
    the projected sentence vectors and the learnable frame embeddings are
    L2-normalised, so their dot products are cosine similarities.

    Args:
        emb_dim: Input and output size of the projection layer.
        num_frames: Accepted for interface compatibility; not used by the
            constructor (the frame count comes from ``frame_emb_init``).
        frame_emb_init: Initial frame embeddings (array-like), copied into
            a float32 parameter.
    """

    def __init__(self, emb_dim, num_frames, frame_emb_init):
        super().__init__()
        self.proj = nn.Linear(emb_dim, emb_dim)
        initial_frames = torch.tensor(frame_emb_init, dtype=torch.float32)
        self.frame_embeddings = nn.Parameter(initial_frames)

    def forward(self, sent_emb):
        """Return the similarity of each sentence embedding to every frame."""
        projected = self.proj(sent_emb)
        unit_sentences = F.normalize(projected, dim=-1)
        unit_frames = F.normalize(self.frame_embeddings, dim=-1)
        return unit_sentences @ unit_frames.T
# Trained frame embeddings and the matching similarity-model weights, both
# read from the downloaded frame-detection directory.
frame_embs = np.load(os.path.join(FRAME_DET_DIR, "trained_frame_embeddings.npy"))

# NOTE(review): torch.load unpickles the checkpoint; the file comes from a
# downloaded repository, so make sure that source is trusted.
state_dict = torch.load(
    os.path.join(FRAME_DET_DIR, "best_frame_margin_model.pt"),
    map_location="cpu",
)

frame_model = FrameSimilarityModel(
    emb_dim=768,
    num_frames=frame_embs.shape[0],
    frame_emb_init=frame_embs,
)
frame_model = frame_model.to(device)
frame_model.load_state_dict(state_dict)
frame_model.eval()  # inference mode
def predict_frame(sentence):
    """Pick the frame whose embedding is most similar to ``sentence``.

    Returns:
        A ``(frame_name, similarity)`` pair. ``frame_name`` is ``None``
        when the best similarity is below ``THRESHOLD``; the similarity is
        returned as a float in both cases.
    """
    batch = get_embedding(sentence).unsqueeze(0)
    with torch.no_grad():
        similarities = frame_model(batch)

    best = similarities.max(dim=1)
    score = best.values.item()
    # The frame list is only indexed when the score passes the threshold.
    name = None if score < THRESHOLD else frame_names[best.indices.item()]
    return name, score
# -------------------------
# Frame Elements
# -------------------------
# Maximum number of per-frame taggers kept loaded at the same time. Every
# cached entry holds a complete token-classification model in memory, so
# this trades RAM for not re-reading the model from disk on each request.
_FE_MODEL_CACHE_SIZE = 4


@functools.lru_cache(maxsize=_FE_MODEL_CACHE_SIZE)
def _load_frame_element_model(frame_dir):
    """Load the tokenizer, tagger and id->label map stored in ``frame_dir``.

    Results are memoised per directory (least-recently-used eviction), so
    repeated requests for the same frame reuse the already loaded model.

    Returns:
        A ``(tokenizer, model, id2label)`` tuple; the model is on
        ``device`` and in eval mode, ``id2label`` maps int ids to labels.
    """
    with open(os.path.join(frame_dir, "label2id.json"), encoding="utf-8") as f:
        label2id = json.load(f)
    id2label = {int(v): k for k, v in label2id.items()}

    tokenizer = AutoTokenizer.from_pretrained(frame_dir)
    model = AutoModelForTokenClassification.from_pretrained(
        frame_dir,
        num_labels=len(label2id),
        id2label=id2label,
        label2id=label2id,
    ).to(device)
    model.eval()
    return tokenizer, model, id2label


def predict_frame_elements(sentence, frame_name):
    """Tag the tokens of ``sentence`` with the frame elements of a frame.

    Args:
        sentence: Input sentence.
        frame_name: Frame whose tagger (the sub-directory of
            ``FE_BASE_DIR`` with that name) should be used.

    Returns:
        A list of ``(token, label)`` tuples for every token whose predicted
        label is not ``"O"``, in sentence order. Empty when no model
        directory exists for ``frame_name``.
    """
    frame_dir = os.path.join(FE_BASE_DIR, frame_name)
    if not os.path.exists(frame_dir):
        return []

    tokenizer, model, id2label = _load_frame_element_model(frame_dir)

    # Inputs must live on the same device as the model.
    inputs = tokenizer(
        sentence,
        return_tensors="pt",
        truncation=True,
        max_length=128,
    ).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits

    # tolist() yields plain Python ints and works on any device.
    label_ids = logits.argmax(dim=-1).squeeze(0).tolist()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze(0).tolist())

    # NOTE(review): tokens are the tokenizer's own pieces; sub-word pieces
    # are reported individually and are not merged back into words.
    special_tokens = {"[CLS]", "[SEP]", "[PAD]"}
    elements = []
    for tok, lab_id in zip(tokens, label_ids):
        if tok in special_tokens:
            continue
        label = id2label[lab_id]
        if label != "O":
            elements.append((tok, label))
    return elements
# -------------------------
# Triple Extraction
# -------------------------
# Triple templates, read once at import time. extract_relations() filters
# this table on the "Frame" column and reads the "Subject", "Relation" and
# "Object" columns of the matching rows.
triples_df = pd.read_excel(TRIPLES_PATH)
def group_elements(elements):
    """Group tokens by label.

    Args:
        elements: Iterable of ``(token, label)`` pairs.

    Returns:
        A dict mapping each label to the list of its tokens, keeping the
        order in which labels and tokens were first seen.
    """
    grouped = {}
    for token, label in elements:
        if label not in grouped:
            grouped[label] = []
        grouped[label].append(token)
    return grouped
def extract_relations(frame_name, elements):
    """Instantiate the triple templates of a frame with the tagged tokens.

    For every row of ``triples_df`` whose ``Frame`` equals ``frame_name``
    and whose ``Subject`` and ``Object`` labels both occur in ``elements``,
    one relation is produced per (subject token, object token) pair.

    Args:
        frame_name: Frame whose templates should be applied.
        elements: ``(token, label)`` pairs as returned by
            ``predict_frame_elements``.

    Returns:
        A list of dicts with the keys ``subject``, ``relation``,
        ``object``, ``subject_fe`` and ``object_fe``.
    """
    tokens_by_fe = group_elements(elements)
    frame_rows = triples_df[triples_df["Frame"] == frame_name]

    found = []
    for _, row in frame_rows.iterrows():
        # A template applies only when both of its roles were tagged.
        if row["Subject"] not in tokens_by_fe or row["Object"] not in tokens_by_fe:
            continue
        found.extend(
            {
                "subject": subject_token,
                "relation": row["Relation"],
                "object": object_token,
                "subject_fe": row["Subject"],
                "object_fe": row["Object"],
            }
            for subject_token in tokens_by_fe[row["Subject"]]
            for object_token in tokens_by_fe[row["Object"]]
        )
    return found
# -------------------------
# Main pipeline
# -------------------------
def analyze(sentence):
    """Run frame detection, frame-element tagging and triple extraction.

    Returns:
        A dict with the keys ``frame``, ``similarity`` (rounded to three
        decimals), ``elements`` and ``relations``. When no frame passes the
        threshold, ``frame`` holds the Persian "out of domain" message and
        both lists are empty.
    """
    frame, sim = predict_frame(sentence)

    # Start from the "out of domain" answer and fill it in on success.
    result = {
        "frame": "خارج از دامنه",
        "similarity": round(sim, 3),
        "elements": [],
        "relations": [],
    }
    if frame is not None:
        elements = predict_frame_elements(sentence, frame)
        result["frame"] = frame
        result["elements"] = elements
        result["relations"] = extract_relations(frame, elements)
    return result
# -------------------------
# Gradio UI
# -------------------------
def ui(sentence):
    """Gradio callback: return the analysis of ``sentence`` unchanged."""
    result = analyze(sentence)
    return result
# One Persian-sentence text box in, the analysis shown as JSON out.
_sentence_input = gr.Textbox(
    label="جمله فارسی",
    placeholder="مثال: علی از تهران به مشهد سفر کرد",
)
_result_output = gr.JSON(label="خروجی")

demo = gr.Interface(
    fn=ui,
    inputs=_sentence_input,
    outputs=_result_output,
    title="Persian Semantic Frame & Triple Extractor",
    description="تشخیص فریم، عناصر فریم و استخراج tripleهای معنایی",
)

if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)