OmidSakaki's picture
Update app.py
7265ce7 verified
import gradio as gr
from transformers import pipeline
import torch
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import re
# لیست کلمات توقف دستی
MANUAL_STOPWORDS = ['و', 'در', 'به', 'از', 'که', 'این', 'را', 'با', 'است', 'برای', 'روی', 'یک', 'ها', 'های', 'می', 'شود', 'شده', 'کرد', 'شدن']
class PersianRAGSystem:
def __init__(self):
# تنظیمات اولیه
self.device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {self.device}")
# استفاده از pipeline برای تولید متن
print("Loading text generation model...")
try:
self.generator = pipeline(
"text-generation",
model="HooshvareLab/gpt2-fa",
device=0 if self.device == "cuda" else -1,
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
)
print("مدل متن‌سازی بارگذاری شد!")
except Exception as e:
print(f"خطا در بارگذاری مدل: {e}")
# Fallback to CPU if CUDA fails
self.generator = pipeline(
"text-generation",
model="HooshvareLab/gpt2-fa",
device=-1,
torch_dtype=torch.float32,
)
print("مدل با استفاده از CPU بارگذاری شد!")
# استفاده از TF-IDF برای محاسبه شباهت
self.vectorizer = TfidfVectorizer(stop_words=MANUAL_STOPWORDS)
self.tfidf_matrix = None
self.documents = []
self.section_titles = []
self.load_sample_data()
def load_sample_data(self):
"""بارگذاری داده‌های مربوط به گروه مپنا با ساختار معنایی"""
mapna_text = """
گروه مپنا تعریف:
- گروه مپنا گروه صنعتی پیشرو در ایران است.
- سال تأسیس: ۱۳۷۱.
- نام مپنا از مدیریت طرح‌های نیروگاهی گرفته شده.
- نوع شرکت: conglomerate صنعتی متنوع.
فعالیت‌های اصلی مپنا:
- مهندسی و ساخت نیروگاه EPC.
- خدمات نفت و گاز.
- حمل و نقل ریلی.
- انرژی تجدیدپذیر باد خورشید آب.
- تجهیزات صنعتی.
- خدمات فنی.
شرکت‌های زیرمجموعه مپنا:
- مپنا اپکو مهندسی تجهیزات نیروگاهی.
- بهره‌برداری و نگهداری مپنا.
- توربین مپنا.
- ژنراتور مپنا.
- بین‌المللی مپنا دبی.
- صنایع ریلی مپنا.
- انرژی خورشیدی مپنا.
- انرژی بادی مپنا.
پروژه‌های مهم مپنا:
- نیروگاه سیکل ترکیبی ایران.
- مزارع بادی منجیل تاکستان.
- نیروگاه خورشیدی مرکز ایران.
- برقی‌سازی راه‌آهن.
- پردازش نفت و گاز.
- اتوماسیون صنعتی.
حضور بین‌المللی مپنا:
- عراق ساخت نیروگاه.
- سوریه تولید برق.
- عمان خدمات نفت گاز.
- ترکیه تجهیزات انرژی.
- آذربایجان زیرساخت.
تحقیق و توسعه مپنا:
- مرکز تحقیقات سال ۱۳۸۷.
- فناوری توربین انرژی تجدیدپذیر اتوماسیون.
- همکاری دانشگاه‌ها تحقیقاتی.
- اختراع انرژی صنعتی.
عملکرد مالی مپنا:
- درآمد بیش از ۲ میلیارد دلار سالانه.
- رشد مستمر دهه گذشته.
- درآمد متنوع بخش‌ها.
- صادرات قوی منطقه‌ای.
ابتکارات پایداری مپنا:
- توسعه انرژی تجدیدپذیر.
- حفاظت محیط زیست.
- توسعه جامعه.
- بهره‌وری انرژی.
جوایز مپنا:
- صادرات ملی.
- برترین صنعتی ایران.
- تعالی کیفیت انرژی.
- مسئولیت محیط زیستی.
"""
# تقسیم متن به بخش‌های معنایی و ذخیره عناوین
sections = re.split(r'\n\s*\n', mapna_text.strip())
self.documents = []
self.section_titles = []
for section in sections:
section = section.strip()
if section:
# استخراج عنوان بخش
title = section.split('\n')[0].strip(':')
self.section_titles.append(title)
self.documents.append(section)
# ایجاد TF-IDF matrix
self.tfidf_matrix = self.vectorizer.fit_transform(self.documents)
print(f"ایندکس کردن {len(self.documents)} بخش معنایی انجام شد")
def find_relevant_documents(self, question, top_k=3):
"""یافتن مستندات مرتبط با سوال با تطبیق کلمات کلیدی و TF-IDF"""
if self.tfidf_matrix is None:
return []
question_vec = self.vectorizer.transform([question])
# محاسبه شباهت
similarities = cosine_similarity(question_vec, self.tfidf_matrix)
# تطبیق کلمات کلیدی سوال با عناوین بخش‌ها
question_words = set(question.lower().split())
best_match_idx = None
max_overlap = 0
for idx, title in enumerate(self.section_titles):
title_words = set(title.lower().split())
overlap = len(question_words.intersection(title_words))
if overlap > max_overlap and similarities[0][idx] > 0.3:
max_overlap = overlap
best_match_idx = idx
if best_match_idx is not None:
return [self.documents[best_match_idx]]
# در صورت عدم تطبیق، استفاده از TF-IDF
top_indices = similarities.argsort()[0][-top_k:][::-1]
relevant_docs = []
for i in top_indices:
if similarities[0][i] > 0.3:
relevant_docs.append(self.documents[i])
return relevant_docs[:1] if relevant_docs else []
def extract_answer_from_context(self, question, context):
"""استخراج پاسخ مستقیم از context با بازگشت کل بخش معنایی"""
sections = context.split('\n\n')
sections = [s.strip() for s in sections if s.strip()]
if not sections:
return None
# ایجاد TF-IDF برای بخش‌ها و سوال
local_vectorizer = TfidfVectorizer(stop_words=MANUAL_STOPWORDS)
section_vecs = local_vectorizer.fit_transform(sections)
question_vec = local_vectorizer.transform([question])
# محاسبه شباهت
similarities = cosine_similarity(question_vec, section_vecs)
# یافتن بخش با بیشترین شباهت
best_index = similarities.argsort()[0][-1]
if similarities[0][best_index] > 0.35:
# بازگرداندن کل بخش معنایی با قالب‌بندی زیبا
answer = sections[best_index]
answer = answer.replace('-', '\n-')
return answer
return None
def generate_simple_response(self, context, question):
"""تولید پاسخ ساده با مدل با prompt بهبود یافته"""
try:
prompt = f"""بر اساس اطلاعات زیر، به سوال پاسخ دقیق و کامل بده. پاسخ باید مستقیماً از اطلاعات داده‌شده استخراج شود و هیچ اطلاعات اضافی، غیرمرتبط یا خلاقانه اضافه نشود. اگر سوال به فهرستی اشاره دارد، تمام موارد مرتبط را به‌صورت کامل و بدون تغییر فهرست کن. پاسخ را به‌صورت متن ساده یا فهرست ارائه کن و از تولید محتوای جدید خودداری کن:
اطلاعات: {context}
سوال: {question}
پاسخ دقیق و کامل:"""
result = self.generator(
prompt,
max_new_tokens=400,
num_return_sequences=1,
temperature=0.05,
do_sample=True,
top_p=0.95,
repetition_penalty=2.5
)
response = result[0]['generated_text']
response = response.replace(prompt, "").strip()
# پاک‌سازی پاسخ: حفظ فهرست‌های کامل
lines = response.split('\n')
clean_response = []
for line in lines:
line = line.strip()
if line and not line.startswith('پاسخ دقیق و کامل:'):
clean_response.append(line)
# اگر پاسخ فهرستی است، تمام موارد را نگه دار
if any(line.startswith('-') for line in clean_response):
return '\n'.join(clean_response)
return '\n'.join(clean_response) or "پاسخ به این سوال در اطلاعات من وجود ندارد."
except Exception as e:
return f"خطا در تولید پاسخ: {str(e)}"
def answer_question(self, question):
"""پاسخ به سوال"""
if not self.documents:
return "سیستم هنوز اسناد را ایندکس نکرده است."
# یافتن مستندات مرتبط
relevant_docs = self.find_relevant_documents(question)
if not relevant_docs:
return "متاسفم، پاسخ به این سوال در اطلاعات من وجود ندارد."
# استخراج context
context = "\n\n".join(relevant_docs)
# ابتدا سعی کن پاسخ را مستقیماً از context استخراج کنی
direct_answer = self.extract_answer_from_context(question, context)
if direct_answer:
return direct_answer
# اگر سوال به فهرست اشاره دارد، از مدل generative استفاده نشود
if any(keyword in question for keyword in ['چیست', 'کدام', 'شامل']):
return "متاسفم، پاسخ دقیق در اطلاعات یافت نشد."
# در غیر این صورت، از مدل generative استفاده کن
response = self.generate_simple_response(context, question)
return response
# سوالات نمونه
sample_questions = [
"گروه مپنا چیست؟",
"فعالیت‌های اصلی مپنا چیست؟",
"شرکت‌های زیرمجموعه مپنا کدام‌اند؟",
"پروژه‌های مهم مپنا چیست؟",
"حضور بین‌المللی مپنا در کدام کشورهاست؟",
"تحقیق و توسعه مپنا شامل چیست؟",
"عملکرد مالی مپنا چگونه است؟",
"جوایز مپنا چیست؟",
"ابتکارات پایداری مپنا شامل چیست؟",
"سال تأسیس مپنا چه زمانی است؟"
]
# ایجاد نمونه از سیستم
rag_system = PersianRAGSystem()
# رابط کاربری Gradio
with gr.Blocks(title="سامانه پاسخگویی هوشمند فارسی - گروه مپنا", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
<div style="text-align: right; direction: rtl;">
سامانه پاسخگویی هوشمند فارسی
این سیستم به سوالات شما درباره گروه مپنا پاسخ می‌دهد. سوال خود را در کادر زیر وارد کنید یا از سوالات نمونه استفاده کنید.
</div>
""")
with gr.Row():
with gr.Column(scale=3):
chatbot = gr.Chatbot(height=500, label="گفتگو", show_copy_button=True)
with gr.Row():
question_input = gr.Textbox(
placeholder="سوال خود را بپرسید...",
label="سوال",
scale=4,
elem_classes="question-input"
)
submit_btn = gr.Button("ارسال سوال", variant="primary", scale=1)
clear = gr.Button("پاک کردن چت", variant="secondary")
with gr.Column(scale=1):
gr.Markdown("""
<div style="text-align: right; direction: rtl;">
سوالات نمونه
برای تست سیستم، می‌توانید روی یکی از سوالات زیر کلیک کنید:
</div>
""")
# ایجاد دکمه‌های سوالات نمونه در یک Grid
with gr.Column():
for question in sample_questions:
btn = gr.Button(
question,
size="sm",
variant="secondary",
elem_classes="sample-question-btn"
)
btn.click(
fn=lambda q=question: q,
outputs=question_input
)
# توابع کنترل کننده رویدادها
def user(user_message, history):
return "", history + [[user_message, None]]
def bot(history):
question = history[-1][0]
answer = rag_system.answer_question(question)
history[-1][1] = answer
return history
# اتصال رویدادها
question_input.submit(user, [question_input, chatbot], [question_input, chatbot], queue=False).then(
bot, chatbot, chatbot
)
submit_btn.click(user, [question_input, chatbot], [question_input, chatbot], queue=False).then(
bot, chatbot, chatbot
)
clear.click(lambda: None, None, chatbot, queue=False)
# افزودن استایل‌های سفارشی برای بهبود ظاهر
css = """
.sample-question-btn {
width: 100%;
margin-bottom: 5px;
text-align: right;
direction: rtl;
white-space: normal;
min-height: 50px;
}
.question-input {
direction: rtl;
text-align: right;
}
h1, h2, h3 {
text-align: right;
direction: rtl;
}
"""
demo.css = css
if __name__ == "__main__":
demo.launch(share=False)