import json
import logging
import os
import time
from threading import Thread

import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ======================================================
# Load Configuration
# ======================================================
DEFAULT_MODEL_ID = "anaspro/Lahja-iraqi-4B"


def _default_config():
    """Return a fresh copy of the built-in settings.

    A new dict is built on every call so callers can mutate the result
    without affecting later fallbacks.
    """
    return {
        "model": {"model_id": DEFAULT_MODEL_ID},
        "generation": {
            "max_new_tokens": 1024,
            "temperature": 0.7,
            "top_p": 0.9,
            "top_k": 50,
            "do_sample": True,
            "repetition_penalty": 1.1,
            "timeout_seconds": 60,
        },
        "interface": {"max_context_length": 4096},
    }


def load_config():
    """Load configuration from config.json in the current working directory.

    Returns:
        dict: The parsed JSON object. The built-in defaults are returned
        instead when the file is missing, is not valid JSON, or its top-level
        value is not a JSON object, so a broken config file cannot crash the
        module at import time.
    """
    try:
        with open("config.json", "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except FileNotFoundError:
        logger.warning("config.json not found, using default settings")
        return _default_config()
    except json.JSONDecodeError as e:
        logger.error("config.json is not valid JSON (%s), using default settings", e)
        return _default_config()

    # The rest of the file calls .get() on the config, so it must be a mapping.
    if not isinstance(loaded, dict):
        logger.error("config.json must contain a JSON object, using default settings")
        return _default_config()
    return loaded


config = load_config()

# ======================================================
# Settings
# ======================================================
# Tolerate a config.json without a "model" section (or with it set to null),
# consistent with the .get()-with-default lookups used for the other sections.
MODEL_ID = (config.get("model") or {}).get("model_id", DEFAULT_MODEL_ID)

# Load system prompt from external file
try:
    with open("system_prompt.txt", "r", encoding="utf-8") as f:
        SYSTEM_PROMPT = f.read()
except FileNotFoundError:
    logger.warning("system_prompt.txt not found, using default prompt")
    SYSTEM_PROMPT = "أنت مساعد ذكي مفيد. تحدث بالعربية وساعد المستخدم في استفساراته."
# Login to Hugging Face
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    logger.info("🔐 Logged in to Hugging Face")

# Global model variables
model = None
tokenizer = None
# Plain boolean flag marking a load in progress; it is not a real lock.
model_lock = False


# ======================================================
# Model loading function
# ======================================================
def load_model():
    """Load the tokenizer and model named by MODEL_ID into the module globals.

    Returns:
        bool: True when both were loaded; False when another load is already
        flagged as in progress or when loading raised an exception (the
        exception is logged, not propagated).
    """
    global model, tokenizer, model_lock

    if model_lock:
        logger.info("Model loading already in progress...")
        return False

    model_lock = True
    try:
        logger.info("🔄 Loading model...")

        # Load tokenizer first
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            use_fast=True,
        )

        # Add padding token if missing
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load model with optimized settings
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
            attn_implementation="flash_attention_2" if torch.cuda.is_available() else None,
            low_cpu_mem_usage=True,
        )
        model.eval()

        # Clear cache to free memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info("✅ Model loaded successfully!")
        return True
    except Exception as e:
        logger.error(f"❌ Error loading model: {str(e)}")
        return False
    finally:
        # Always clear the flag so a failed load can be retried.
        model_lock = False


# ======================================================
# Chat function (ZeroGPU)
# ======================================================
def _history_to_messages(history):
    """Convert Gradio chat history into a list of role/content dicts.

    Accepts both the "messages" format (dicts with "role" and "content") and
    the legacy pair format ([user_msg, assistant_msg]). Entries with any other
    shape or role are skipped.
    """
    converted = []
    for exchange in history or []:
        if isinstance(exchange, dict):
            role = exchange.get("role")
            if role in ("user", "assistant"):
                converted.append({"role": role, "content": exchange.get("content", "")})
        elif isinstance(exchange, (list, tuple)) and len(exchange) >= 2:
            if exchange[0]:  # User message
                converted.append({"role": "user", "content": str(exchange[0])})
            if exchange[1]:  # Assistant message
                converted.append({"role": "assistant", "content": str(exchange[1])})
    return converted


@spaces.GPU(duration=120)
def chat(message, history):
    """Stream a model reply to `message`, given the previous `history`.

    This function is a generator: each yielded value is the full reply text
    accumulated so far. Error conditions are reported by yielding a single
    user-facing message and then stopping. They must be yielded rather than
    returned, because a value returned from a generator is never delivered
    to the consumer.

    Args:
        message: The current user message.
        history: Previous turns, in either Gradio history format.

    Yields:
        str: The partial reply so far, or an error message.
    """
    # Load model if not already loaded
    if model is None or tokenizer is None:
        if not load_model():
            yield "❌ عذراً، حدث خطأ في تحميل النموذج. يرجى المحاولة مرة أخرى."
            return

    try:
        # Reject empty / whitespace-only input before doing any work.
        if not (message and message.strip()):
            yield "يرجى كتابة رسالة صحيحة."
            return

        # ======================================================
        # Build conversation properly
        # ======================================================
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        messages.extend(_history_to_messages(history))
        messages.append({"role": "user", "content": message.strip()})

        # ======================================================
        # Tokenize input with error handling
        # ======================================================
        try:
            max_length = config.get("interface", {}).get("max_context_length", 4096)
            # NOTE(review): with the tokenizer's default truncation side, an
            # over-long conversation is presumably cut at the end (newest turn
            # and generation prompt) - confirm and consider trimming old turns.
            input_ids = tokenizer.apply_chat_template(
                messages,
                return_tensors="pt",
                add_generation_prompt=True,
                truncation=True,
                max_length=max_length,
            ).to(model.device)
        except Exception as e:
            logger.error(f"Tokenization error: {e}")
            yield "❌ خطأ في معالجة الرسالة. يرجى المحاولة مرة أخرى."
            return

        # ======================================================
        # Setup text streamer
        # ======================================================
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        generation_config = config.get("generation", {})
        timeout = generation_config.get("timeout_seconds", 60)
        generation_kwargs = {
            "input_ids": input_ids,
            "streamer": streamer,
            "max_new_tokens": generation_config.get("max_new_tokens", 1024),
            "temperature": generation_config.get("temperature", 0.7),
            "top_p": generation_config.get("top_p", 0.9),
            "top_k": generation_config.get("top_k", 50),
            "do_sample": generation_config.get("do_sample", True),
            "repetition_penalty": generation_config.get("repetition_penalty", 1.1),
            "pad_token_id": tokenizer.pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "use_cache": True,
            # Make generate() itself stop at the timeout; breaking out of the
            # consumer loop below does not stop the worker thread.
            "max_time": timeout,
        }

        # ======================================================
        # Generate output in a separate thread with timeout
        # ======================================================
        generation_errors = []

        def _run_generation():
            """Run generate(); on failure record the error and end the stream."""
            try:
                model.generate(**generation_kwargs)
            except Exception as e:
                logger.error(f"Generation error: {e}")
                generation_errors.append(e)
                # Without this the consumer loop would wait indefinitely for
                # an end-of-stream signal that never arrives.
                streamer.end()

        thread = Thread(target=_run_generation)
        thread.daemon = True
        thread.start()

        partial_text = ""
        start_time = time.time()

        try:
            for new_text in streamer:
                if time.time() - start_time > timeout:
                    logger.warning("Generation timeout reached")
                    break
                partial_text += new_text
                yield partial_text
        except Exception as e:
            logger.error(f"Generation error: {e}")
            yield "❌ حدث خطأ أثناء توليد الإجابة. يرجى المحاولة مرة أخرى."
        else:
            # Surface failures that happened inside the worker thread.
            if generation_errors:
                yield "❌ حدث خطأ أثناء توليد الإجابة. يرجى المحاولة مرة أخرى."

        thread.join(timeout=5)  # Give thread 5 seconds to finish

        # Clear GPU cache after generation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    except Exception as e:
        logger.error(f"Chat function error: {e}")
        yield f"❌ حدث خطأ غير متوقع: {str(e)}"


# ======================================================
# Gradio Interface with enhanced styling
# ======================================================
def create_interface():
    """Build and return the Gradio Blocks app wrapping the chat function.

    Returns:
        gr.Blocks: The assembled (not yet launched) interface.
    """
    # Custom CSS for better styling
    custom_css = """
    .gradio-container {
        max-width: 1000px !important;
        margin: auto !important;
    }
    .chat-message {
        padding: 10px !important;
        margin: 5px 0 !important;
        border-radius: 10px !important;
    }
    .message {
        font-size: 16px !important;
        line-height: 1.5 !important;
    }
    .title {
        text-align: center !important;
        color: #2563eb !important;
        margin-bottom: 20px !important;
    }
    .description {
        text-align: center !important;
        margin-bottom: 30px !important;
        color: #6b7280 !important;
    }
    """

    with gr.Blocks(
        css=custom_css,
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="gray",
            neutral_hue="slate",
        ),
        title="دعم فني - NB TEL",
    ) as demo:
        gr.Markdown(
            """
            # 📞 دعم فني - NB TEL Internet Assistant

            **مساعد ذكي لخدمة الدعم الفني في شبكة النور - NB TEL**

            تحدث معه كأنك زبون: اشرح مشكلتك، اسأل عن الباقات، أو اطلب تذكرة دعم.
            """,
            elem_classes=["title", "description"],
        )

        # Chat interface
        chatbot = gr.ChatInterface(
            fn=chat,
            type="messages",
            examples=[
                ["الإنترنت عندي مقطوع من الصبح، شنو السبب؟"],
                ["أريد أرقّي الباقة إلى 50 ميج."],
                ["ضوء الـ LOS في جهاز الفايبر أحمر، شنو معناها؟"],
                ["كم سعر باقة الإنترنت اللامحدود؟"],
                ["المودم يفصل ويوصل باستمرار، شنو الحل؟"],
            ],
            cache_examples=False,
            retry_btn="🔄 إعادة المحاولة",
            undo_btn="↶ تراجع",
            clear_btn="🗑️ مسح المحادثة",
            submit_btn="إرسال 📤",
            textbox=gr.Textbox(
                placeholder="اكتب استفسارك هنا... 💬",
                container=False,
                scale=7,
            ),
        )

        # Footer with information
        gr.Markdown(
            """
            ---
            **ملاحظة:** هذا مساعد ذكي للمحاكاة. البيانات المعروضة هي للتدريب فقط.

            **الباقات المتاحة:**
            - 🏠 HOME-10M: 10 Mbps - $9.99/شهر
            - 🏠 HOME-50M: 50 Mbps - $19.99/شهر
            - 🏢 BUS-200M: 200 Mbps - $69.99/شهر
            - ⚡ UNL-1G: 1 Gbps غير محدود - $149.99/شهر
            """,
            elem_classes=["description"],
        )

    return demo


# Create the interface
demo = create_interface()

if __name__ == "__main__":
    demo.launch()