|
|
import os |
|
|
import torch |
|
|
import gradio as gr |
|
|
import spaces |
|
|
import json |
|
|
import time |
|
|
from threading import Thread |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer |
|
|
from huggingface_hub import login |
|
|
import logging |
|
|
|
|
|
|
|
|
# Module-wide logging: INFO level so model-loading progress shows in Space logs.
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)  # per-module logger, standard convention
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_config():
    """Load generation/interface settings from ``config.json``.

    Returns:
        dict: The parsed configuration. When the file is missing *or*
        contains invalid JSON, a built-in default configuration is
        returned instead, so app startup never fails on a bad config.
    """
    try:
        with open("config.json", "r", encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        # A corrupt config.json previously crashed startup with an
        # uncaught JSONDecodeError; treat it like a missing file.
        logger.warning("config.json unavailable (%s), using default settings", e)
        return {
            "model": {"model_id": "anaspro/Lahja-iraqi-4B"},
            "generation": {
                "max_new_tokens": 1024,
                "temperature": 0.7,
                "top_p": 0.9,
                "top_k": 50,
                "do_sample": True,
                "repetition_penalty": 1.1,
                "timeout_seconds": 60
            },
            "interface": {"max_context_length": 4096}
        }
|
|
|
|
|
# ---------------------------------------------------------------------------
# Module-level initialisation: config, model id, system prompt, HF auth state
# ---------------------------------------------------------------------------

config = load_config()

# Which Hub repository to load; overridable through config.json.
MODEL_ID = config["model"].get("model_id", "anaspro/Lahja-iraqi-4B")

# Prefer the bundled system prompt file; otherwise fall back to a default.
try:
    with open("system_prompt.txt", "r", encoding="utf-8") as prompt_file:
        SYSTEM_PROMPT = prompt_file.read()
except FileNotFoundError:
    logger.warning("system_prompt.txt not found, using default prompt")
    SYSTEM_PROMPT = "أنت مساعد ذكي مفيد. تحدث بالعربية وساعد المستخدم في استفساراته."

# Authenticate against the Hub when a token is configured (gated models etc.).
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    logger.info("🔐 Logged in to Hugging Face")

# Lazily-populated globals; model_lock is a plain flag guarding load_model().
model = None
tokenizer = None
model_lock = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model():
    """Load the model and tokenizer with proper error handling.

    Populates the module globals ``model`` and ``tokenizer``. Returns True
    on success, False on failure or when a load is already in progress.

    NOTE(review): ``model_lock`` is a plain boolean, not a threading.Lock,
    so the check-then-set below is not atomic — two concurrent callers
    could both pass the guard. Presumably acceptable for a single-worker
    Space; confirm if multiple workers/threads can call this.
    """
    global model, tokenizer, model_lock

    # Re-entrancy guard: bail out if another call is mid-load.
    if model_lock:
        logger.info("Model loading already in progress...")
        return False

    model_lock = True
    try:
        logger.info("🔄 Loading model...")

        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            use_fast=True
        )

        # Some causal-LM tokenizers ship without a pad token; reuse EOS so
        # generate() has a valid pad_token_id.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
            # Flash-Attention 2 only when CUDA is present; None lets
            # transformers choose its default implementation.
            attn_implementation="flash_attention_2" if torch.cuda.is_available() else None,
            low_cpu_mem_usage=True
        )

        model.eval()  # inference only — disables dropout/training behavior

        # Release allocator slack left over from the loading process.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info("✅ Model loaded successfully!")
        return True

    except Exception as e:
        logger.error(f"❌ Error loading model: {str(e)}")
        return False
    finally:
        # Always clear the flag so a failed load can be retried.
        model_lock = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU(duration=120)
def chat(message, history):
    """Stream a model reply for *message* given the Gradio chat *history*.

    This is a generator: each ``yield`` sends the accumulated partial text
    to the UI. BUG FIX: the original used ``return "…error…"`` on failure
    paths, but inside a generator a plain ``return value`` only sets
    ``StopIteration.value`` — Gradio never displayed those strings, so
    load failures, empty-input validation, and tokenization errors were
    silently swallowed. All error messages are now *yielded* instead.

    Args:
        message: The user's new message (string).
        history: Prior turns, either "messages"-style dicts
            (``{"role": ..., "content": ...}``) or legacy
            ``(user, assistant)`` pairs.
    """
    global model, tokenizer

    # Lazy-load on first request (the worker may start cold).
    if model is None or tokenizer is None:
        if not load_model():
            yield "❌ عذراً، حدث خطأ في تحميل النموذج. يرجى المحاولة مرة أخرى."
            return

    try:
        # The system prompt always leads the conversation.
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]

        # Accept both Gradio history formats: role/content dicts and
        # legacy (user, assistant) tuples/lists.
        if history:
            for exchange in history:
                if isinstance(exchange, dict):
                    if exchange.get("role") == "user":
                        messages.append({"role": "user", "content": exchange.get("content", "")})
                    elif exchange.get("role") == "assistant":
                        messages.append({"role": "assistant", "content": exchange.get("content", "")})
                elif isinstance(exchange, (list, tuple)) and len(exchange) >= 2:
                    if exchange[0]:
                        messages.append({"role": "user", "content": str(exchange[0])})
                    if exchange[1]:
                        messages.append({"role": "assistant", "content": str(exchange[1])})

        # Reject empty / whitespace-only input before touching the model.
        if message and message.strip():
            messages.append({"role": "user", "content": message.strip()})
        else:
            yield "يرجى كتابة رسالة صحيحة."
            return

        # Tokenize via the model's chat template, truncating old turns so
        # the prompt fits the configured context window.
        try:
            max_length = config.get("interface", {}).get("max_context_length", 4096)
            input_ids = tokenizer.apply_chat_template(
                messages,
                return_tensors="pt",
                add_generation_prompt=True,
                truncation=True,
                max_length=max_length
            ).to(model.device)
        except Exception as e:
            logger.error(f"Tokenization error: {e}")
            yield "❌ خطأ في معالجة الرسالة. يرجى المحاولة مرة أخرى."
            return

        # skip_prompt=True so only newly generated tokens are streamed.
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )

        generation_config = config.get("generation", {})
        generation_kwargs = {
            "input_ids": input_ids,
            "streamer": streamer,
            "max_new_tokens": generation_config.get("max_new_tokens", 1024),
            "temperature": generation_config.get("temperature", 0.7),
            "top_p": generation_config.get("top_p", 0.9),
            "top_k": generation_config.get("top_k", 50),
            "do_sample": generation_config.get("do_sample", True),
            "repetition_penalty": generation_config.get("repetition_penalty", 1.1),
            "pad_token_id": tokenizer.pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "use_cache": True
        }

        # Run generate() on a daemon thread so this generator can stream
        # tokens from the streamer while the model produces them.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.daemon = True
        thread.start()

        partial_text = ""
        start_time = time.time()
        timeout = config.get("generation", {}).get("timeout_seconds", 60)

        try:
            for new_text in streamer:
                # Wall-clock cutoff so a runaway generation can't hold the
                # GPU slot past the configured limit.
                if time.time() - start_time > timeout:
                    logger.warning("Generation timeout reached")
                    break

                partial_text += new_text
                yield partial_text
        except Exception as e:
            logger.error(f"Generation error: {e}")
            yield "❌ حدث خطأ أثناء توليد الإجابة. يرجى المحاولة مرة أخرى."

        # Bounded join: don't hang the request if generate() lags behind.
        thread.join(timeout=5)

        # Return allocator slack between requests.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    except Exception as e:
        logger.error(f"Chat function error: {e}")
        yield f"❌ حدث خطأ غير متوقع: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_interface():
    """Create the Gradio interface with enhanced UI.

    Returns:
        gr.Blocks: The assembled demo (header, streaming chat, plan footer).
    """

    # Layout tweaks: cap the app width, enlarge chat text, style headings.
    custom_css = """
    .gradio-container {
        max-width: 1000px !important;
        margin: auto !important;
    }
    .chat-message {
        padding: 10px !important;
        margin: 5px 0 !important;
        border-radius: 10px !important;
    }
    .message {
        font-size: 16px !important;
        line-height: 1.5 !important;
    }
    .title {
        text-align: center !important;
        color: #2563eb !important;
        margin-bottom: 20px !important;
    }
    .description {
        text-align: center !important;
        margin-bottom: 30px !important;
        color: #6b7280 !important;
    }
    """

    with gr.Blocks(
        css=custom_css,
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="gray",
            neutral_hue="slate"
        ),
        title="دعم فني - NB TEL"
    ) as demo:

        # Header: app title and short usage instructions (Arabic).
        gr.Markdown(
            """
            # 📞 دعم فني - NB TEL Internet Assistant

            **مساعد ذكي لخدمة الدعم الفني في شبكة النور - NB TEL**

            تحدث معه كأنك زبون: اشرح مشكلتك، اسأل عن الباقات، أو اطلب تذكرة دعم.
            """,
            elem_classes=["title", "description"]
        )

        # Chat widget wired to the streaming `chat` generator.
        # NOTE(review): `retry_btn`, `undo_btn`, `clear_btn`, and a string
        # `submit_btn` were removed from gr.ChatInterface in Gradio 5.x —
        # confirm the pinned Gradio version accepts these kwargs.
        chatbot = gr.ChatInterface(
            fn=chat,
            type="messages",
            # Canned example queries shown beneath the textbox (Iraqi Arabic).
            examples=[
                ["الإنترنت عندي مقطوع من الصبح، شنو السبب؟"],
                ["أريد أرقّي الباقة إلى 50 ميج."],
                ["ضوء الـ LOS في جهاز الفايبر أحمر، شنو معناها؟"],
                ["كم سعر باقة الإنترنت اللامحدود؟"],
                ["المودم يفصل ويوصل باستمرار، شنو الحل؟"]
            ],
            # Examples would otherwise trigger generation at build time.
            cache_examples=False,
            retry_btn="🔄 إعادة المحاولة",
            undo_btn="↶ تراجع",
            clear_btn="🗑️ مسح المحادثة",
            submit_btn="إرسال 📤",
            textbox=gr.Textbox(
                placeholder="اكتب استفسارك هنا... 💬",
                container=False,
                scale=7
            )
        )

        # Footer: disclaimer plus the static plan/pricing table.
        gr.Markdown(
            """
            ---
            **ملاحظة:** هذا مساعد ذكي للمحاكاة. البيانات المعروضة هي للتدريب فقط.

            **الباقات المتاحة:**
            - 🏠 HOME-10M: 10 Mbps - $9.99/شهر
            - 🏠 HOME-50M: 50 Mbps - $19.99/شهر
            - 🏢 BUS-200M: 200 Mbps - $69.99/شهر
            - ⚡ UNL-1G: 1 Gbps غير محدود - $149.99/شهر
            """,
            elem_classes=["description"]
        )

    return demo
|
|
|
|
|
|
|
|
# Built at import time — presumably so the Spaces runtime can pick up the
# module-level `demo` object; confirm against the Space's SDK configuration.
demo = create_interface()

if __name__ == "__main__":
    # Direct execution (local development): start the Gradio server.
    demo.launch()
|
|
|