Der11 / app.py
Derr11's picture
Update app.py
38e2f41 verified
raw
history blame
21.9 kB
import os
import numpy as np
import torch
import gradio as gr
import spaces
import warnings
warnings.filterwarnings("ignore")
from transformers import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoeProcessor
from qwen_omni_utils import process_mm_info
# =========================================================
# Patches لتجاوز مشاكل التوافق في Qwen3-Omni
# =========================================================
def _patched_mark_tied_weights_as_initialized(self):
"""
تجاوز مشكلة lm_head في tied weights
"""
return
def _patched_init_weights(self, module):
"""
تجاوز مشكلة initializer_range في Qwen3OmniMoeTalkerConfig
"""
# نحاول الحصول على initializer_range، وإذا لم يكن موجود نستخدم قيمة افتراضية
try:
std = self.config.initializer_range
except AttributeError:
# قيمة افتراضية آمنة
std = 0.02
# تطبيق التهيئة الأساسية
if isinstance(module, torch.nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, torch.nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _patched_initialize_weights(self):
"""
تجاوز كامل لدالة initialize_weights في حالة استمرار المشاكل
"""
# نحاول التهيئة العادية، وإذا فشلت نتجاهلها
try:
# محاولة استخدام الدالة الأصلية إذا كانت موجودة
if hasattr(self, '_original_initialize_weights'):
self._original_initialize_weights()
else:
# تهيئة بسيطة آمنة
for module in self.modules():
if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
if hasattr(module, 'weight'):
module.weight.data.normal_(mean=0.0, std=0.02)
if hasattr(module, 'bias') and module.bias is not None:
module.bias.data.zero_()
except Exception as e:
print(f"Warning: Could not initialize weights properly: {e}")
# نستمر بدون تهيئة - النموذج المحمل مسبقاً يجب أن يعمل
# تطبيق الـ patches قبل أي استدعاء لـ from_pretrained
def apply_patches():
"""تطبيق جميع الـ patches اللازمة"""
# Patch 1: تجاوز مشكلة lm_head
if hasattr(Qwen3OmniMoeForConditionalGeneration, "mark_tied_weights_as_initialized"):
Qwen3OmniMoeForConditionalGeneration.mark_tied_weights_as_initialized = (
_patched_mark_tied_weights_as_initialized
)
# Patch 2: تجاوز مشكلة initializer_range
if hasattr(Qwen3OmniMoeForConditionalGeneration, "_init_weights"):
Qwen3OmniMoeForConditionalGeneration._init_weights = _patched_init_weights
# Patch 3: تجاوز initialize_weights بالكامل إذا لزم
if hasattr(Qwen3OmniMoeForConditionalGeneration, "initialize_weights"):
# حفظ الدالة الأصلية
Qwen3OmniMoeForConditionalGeneration._original_initialize_weights = (
Qwen3OmniMoeForConditionalGeneration.initialize_weights
)
Qwen3OmniMoeForConditionalGeneration.initialize_weights = _patched_initialize_weights
# تطبيق الـ patches
apply_patches()
# =========================================================
# إعدادات عامة
# =========================================================
MODEL_PATH = os.getenv("MODEL_PATH", "Qwen/Qwen3-Omni-30B-A3B-Instruct")
USE_AUDIO_IN_VIDEO = True # استخدام الصوت داخل الفيديو إذا وجد
VOICE_CHOICES = ["Ethan", "Chelsie", "Aiden"]
DEFAULT_VOICE = "Ethan"
# سنحمّل النموذج كسولياً (عند أول استدعاء فقط)
model = None
processor = None
def load_model():
"""
تحميل Qwen3-Omni والمعالج عند أول استدعاء فقط.
- نستخدم attn_implementation="eager" لتفادي الحاجة لـ flash-attn.
- نضيف low_cpu_mem_usage=True لتحسين الأداء
- نضيف ignore_mismatched_sizes=True لتجاوز مشاكل الأحجام
"""
global model, processor
if model is not None and processor is not None:
return
print(f"[ZeroGPU] Loading model from: {MODEL_PATH}")
# اختيار نوع البيانات والجهاز
if torch.cuda.is_available():
torch_dtype = torch.bfloat16
device = "cuda"
else:
torch_dtype = torch.float32
device = "cpu"
try:
# محاولة تحميل النموذج مع خيارات إضافية للأمان
local_model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
MODEL_PATH,
torch_dtype=torch_dtype,
attn_implementation="eager", # آمن على ZeroGPU بدون flash-attn
low_cpu_mem_usage=True, # تحسين استخدام الذاكرة
ignore_mismatched_sizes=True, # تجاوز مشاكل الأحجام
trust_remote_code=True, # السماح بالكود المخصص
)
# نقل النموذج إلى الجهاز المناسب
local_model = local_model.to(device)
# وضع النموذج في وضع التقييم (inference)
local_model.eval()
except Exception as e:
print(f"Error loading model: {e}")
print("Attempting alternative loading method...")
# محاولة بديلة مع تعطيل _init_weights
try:
# تعطيل _init_weights مؤقتاً
original_init_weights = None
if hasattr(Qwen3OmniMoeForConditionalGeneration, "_init_weights"):
original_init_weights = Qwen3OmniMoeForConditionalGeneration._init_weights
Qwen3OmniMoeForConditionalGeneration._init_weights = lambda self, module: None
local_model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
MODEL_PATH,
torch_dtype=torch_dtype,
attn_implementation="eager",
low_cpu_mem_usage=True,
)
# استعادة _init_weights
if original_init_weights:
Qwen3OmniMoeForConditionalGeneration._init_weights = original_init_weights
local_model = local_model.to(device)
local_model.eval()
except Exception as e2:
raise RuntimeError(f"Failed to load model: {e2}")
# تحميل المعالج
try:
local_processor = Qwen3OmniMoeProcessor.from_pretrained(
MODEL_PATH,
trust_remote_code=True
)
except Exception as e:
print(f"Error loading processor: {e}")
raise
model = local_model
processor = local_processor
print(f"[ZeroGPU] Model loaded successfully on {device} with dtype {torch_dtype}.")
def build_messages_from_history(
history,
system_prompt,
user_text,
image,
audio_path,
video_path,
):
"""
تحويل تاريخ الدردشة + المدخل الحالي إلى conversation بالـ format
المطلوب من Qwen3-Omni.
history: list of [user_text, assistant_text]
"""
messages = []
if system_prompt:
messages.append(
{
"role": "system",
"content": [{"type": "text", "text": system_prompt}],
}
)
# تاريخ المحادثة
for user_msg, assistant_msg in history:
if user_msg:
messages.append(
{
"role": "user",
"content": [{"type": "text", "text": user_msg}],
}
)
if assistant_msg:
messages.append(
{
"role": "assistant",
"content": [{"type": "text", "text": assistant_msg}],
}
)
# محتوى رسالة المستخدم الحالية
user_content = []
if image is not None:
user_content.append({"type": "image", "image": image})
if audio_path is not None and audio_path != "":
user_content.append({"type": "audio", "audio": audio_path})
if video_path is not None and video_path != "":
user_content.append({"type": "video", "video": video_path})
if user_text and user_text.strip():
user_content.append({"type": "text", "text": user_text.strip()})
if user_content:
messages.append(
{
"role": "user",
"content": user_content,
}
)
return messages
# =========================================================
# دالة الاستدلال (تعمل على ZeroGPU)
# =========================================================
@spaces.GPU(duration=120)
def qwen3_omni_inference(
history,
user_text,
image,
audio_path,
video_path,
system_prompt,
return_audio,
speaker,
temperature,
top_p,
max_tokens,
):
"""
- تنفيذ الاستدلال على ZeroGPU.
- يدعم نص + صورة + صوت + فيديو في نفس الرسالة.
- مخرج نصي دائماً، ومخرج صوتي اختياري.
"""
# في حالة عدم وجود مداخل من المستخدم
if not (user_text or image is not None or audio_path or video_path):
return history, None, "", None, None, None
try:
load_model()
global model, processor
messages = build_messages_from_history(
history=history,
system_prompt=system_prompt,
user_text=user_text,
image=image,
audio_path=audio_path,
video_path=video_path,
)
# بناء نص المحادثة باستخدام chat_template
text_prompt = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=False,
)
# تجهيز الوسائط المتعددة
audios, images, videos = process_mm_info(
messages,
use_audio_in_video=USE_AUDIO_IN_VIDEO,
)
# تحويل إلى تينسورات
inputs = processor(
text=text_prompt,
audio=audios,
images=images,
videos=videos,
return_tensors="pt",
padding=True,
use_audio_in_video=USE_AUDIO_IN_VIDEO,
)
# نقل إلى جهاز النموذج ونفس dtype
first_param = next(model.parameters())
device = first_param.device
# تحويل المدخلات إلى الجهاز المناسب
for key in inputs:
if hasattr(inputs[key], 'to'):
inputs[key] = inputs[key].to(device)
# إعدادات التوليد
gen_kwargs = dict(
temperature=float(temperature) if temperature > 0 else 1e-7,
top_p=float(top_p),
max_new_tokens=int(max_tokens),
do_sample=temperature > 0,
use_audio_in_video=USE_AUDIO_IN_VIDEO,
)
# إضافة thinker_return_dict_in_generate فقط إذا كان مدعوماً
if hasattr(model, 'config') and hasattr(model.config, 'thinker_return_dict_in_generate'):
gen_kwargs["thinker_return_dict_in_generate"] = True
# توليد نص فقط أو نص + صوت
with torch.no_grad():
if not return_audio:
gen_kwargs["return_audio"] = False
outputs = model.generate(**inputs, **gen_kwargs)
text_ids = outputs
audio_out = None
else:
gen_kwargs["speaker"] = speaker
gen_kwargs["return_audio"] = True
outputs = model.generate(**inputs, **gen_kwargs)
if isinstance(outputs, tuple):
text_ids, audio_out = outputs
else:
text_ids = outputs
audio_out = None
# استخراج النص الناتج (بدون مدخل prompt)
input_len = inputs["input_ids"].shape[1]
# التعامل مع الأنواع المختلفة من المخرجات
if hasattr(text_ids, 'sequences'):
generated_ids = text_ids.sequences[:, input_len:]
else:
generated_ids = text_ids[:, input_len:]
generated_text = processor.batch_decode(
generated_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)[0]
# تحديث تاريخ الدردشة
user_display = (
user_text if (user_text and user_text.strip()) else "[Multimodal message]"
)
history = history + [[user_display, generated_text]]
# تجهيز الصوت الناتج إن وجد
gr_audio = None
if audio_out is not None:
try:
audio_np = audio_out.reshape(-1).detach().cpu().numpy()
sample_rate = 24000
gr_audio = (sample_rate, audio_np.astype(np.float32))
except Exception as e:
print(f"Warning: Could not process audio output: {e}")
gr_audio = None
# نعيد: history الجديد + صوت الرد + تفريغ مدخلات المستخدم
return history, gr_audio, "", None, None, None
except Exception as e:
print(f"Error during inference: {e}")
import traceback
traceback.print_exc()
# في حالة الخطأ، نضيف رسالة خطأ للمحادثة
user_display = (
user_text if (user_text and user_text.strip()) else "[Multimodal message]"
)
error_message = f"عذراً، حدث خطأ أثناء معالجة الرسالة: {str(e)}"
history = history + [[user_display, error_message]]
return history, None, "", None, None, None
# =========================================================
# دوال واجهة Gradio
# =========================================================
def clear_chat():
"""إعادة تعيين المحادثة ومخرج الصوت."""
return [], None
def create_interface():
with gr.Blocks(
title="Qwen3-Omni-30B-A3B – ZeroGPU Chat",
theme=gr.themes.Soft(),
) as demo:
gr.Markdown(
"""
<h1 style="text-align:center;">🤖 Qwen3-Omni-30B-A3B – ZeroGPU Chat</h1>
<p style="text-align:center;">
دردشة متعددة الوسائط (نص + صورة + صوت + فيديو) تعمل على ZeroGPU.<br/>
اكتب رسالتك، ويمكنك إضافة صورة/صوت/فيديو، ثم اضغط <b>إرسال</b> أو Enter.<br/>
(لإضافة سطر جديد استخدم Shift+Enter)
</p>
"""
)
with gr.Row():
# العمود الأيسر: المحادثة
with gr.Column(scale=3):
chatbot = gr.Chatbot(
label="المحادثة",
height=480,
elem_id="chatbot",
)
audio_output = gr.Audio(
label="رد النموذج (صوت)",
type="numpy",
autoplay=True,
visible=True,
)
with gr.Row():
user_text = gr.Textbox(
label="رسالتك",
placeholder="اكتب رسالتك هنا (يمكنك أيضاً إرفاق صورة/صوت/فيديو من الأسفل)...",
lines=3,
show_label=False,
elem_id="message",
)
with gr.Row():
image_input = gr.Image(
label="📷 صورة (اختياري)",
type="pil",
sources=["upload", "webcam"],
height=150,
)
audio_input = gr.Audio(
label="🎙️ صوت (اختياري)",
type="filepath",
sources=["microphone", "upload"],
)
video_input = gr.Video(
label="🎬 فيديو (اختياري)",
height=150,
)
with gr.Row():
send_btn = gr.Button("إرسال", variant="primary", scale=2)
clear_btn = gr.Button("مسح المحادثة", variant="secondary")
# العمود الأيمن: الإعدادات
with gr.Column(scale=1):
gr.Markdown("### ⚙️ إعدادات النموذج")
system_prompt = gr.Textbox(
label="System Prompt",
value="You are a helpful, multilingual assistant.",
lines=4,
placeholder="يمكنك التحكم في شخصية النموذج من هنا (اختياري).",
)
return_audio = gr.Checkbox(
label="تفعيل مخرج صوتي (النموذج يتكلم)؟",
value=True,
)
speaker = gr.Dropdown(
label="صوت المتحدث (speaker)",
choices=VOICE_CHOICES,
value=DEFAULT_VOICE,
)
with gr.Accordion("إعدادات متقدمة", open=False):
temperature = gr.Slider(
label="Temperature (العشوائية)",
minimum=0.0,
maximum=1.5,
value=0.6,
step=0.05,
)
top_p = gr.Slider(
label="Top-p (حجم العينة)",
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
)
max_tokens = gr.Slider(
label="Max new tokens (طول الرد الأقصى)",
minimum=16,
maximum=1024,
value=384,
step=16,
)
gr.Markdown(
"""
**📝 ملاحظات:**
- يمكنك إرسال نص فقط، أو نص مع صورة/صوت/فيديو في رسالة واحدة
- Enter للإرسال، Shift+Enter لسطر جديد
- تشغيل النموذج على ZeroGPU قد يستغرق عدة ثوانٍ حسب طول الرسالة
- النموذج يدعم اللغات المتعددة بما فيها العربية والإنجليزية
"""
)
# حالة المحادثة
history_state = gr.State([])
# مدخلات دالة الإرسال
send_inputs = [
history_state,
user_text,
image_input,
audio_input,
video_input,
system_prompt,
return_audio,
speaker,
temperature,
top_p,
max_tokens,
]
send_outputs = [
history_state,
audio_output,
user_text,
image_input,
audio_input,
video_input,
]
# إرسال بالزر
send_btn.click(
fn=qwen3_omni_inference,
inputs=send_inputs,
outputs=send_outputs,
queue=True,
).then(
lambda h: h,
inputs=history_state,
outputs=chatbot,
)
# إرسال بالـ Enter من Textbox
user_text.submit(
fn=qwen3_omni_inference,
inputs=send_inputs,
outputs=send_outputs,
queue=True,
).then(
lambda h: h,
inputs=history_state,
outputs=chatbot,
)
# مسح المحادثة
clear_btn.click(
fn=clear_chat,
inputs=None,
outputs=[history_state, audio_output],
).then(
lambda: [],
inputs=None,
outputs=chatbot,
).then(
lambda: ("", None, None, None),
inputs=None,
outputs=[user_text, image_input, audio_input, video_input],
)
# رسالة تحميل النموذج عند بدء التطبيق
demo.load(
lambda: gr.Info("جاري تحميل النموذج... قد يستغرق هذا بضع دقائق في المرة الأولى."),
inputs=None,
outputs=None,
)
return demo
# إنشاء الواجهة
demo = create_interface()
if __name__ == "__main__":
# إطلاق التطبيق
demo.launch(
ssr_mode=False, # تعطيل SSR لتجنب مشاكل "Starting..."
show_error=True, # عرض الأخطاء بشكل واضح
)