anaspro committed on
Commit
40db06d
·
1 Parent(s): 0cb7257
Files changed (1) hide show
  1. app.py +320 -16
app.py CHANGED
@@ -1,21 +1,325 @@
1
- from transformers import pipeline
2
  import torch
 
 
 
 
 
 
 
 
3
 
4
- model_id = "openai/gpt-oss-20b"
 
 
5
 
6
- pipe = pipeline(
7
- "text-generation",
8
- model=model_id,
9
- torch_dtype="auto",
10
- device_map="auto",
11
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- messages = [
14
- {"role": "user", "content": "Explain quantum mechanics clearly and concisely."},
15
- ]
16
 
17
- outputs = pipe(
18
- messages,
19
- max_new_tokens=256,
20
- )
21
- print(outputs[0]["generated_text"][-1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
  import torch
3
+ import gradio as gr
4
+ import spaces
5
+ import json
6
+ import time
7
+ from threading import Thread
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
9
+ from huggingface_hub import login
10
+ import logging
11
 
12
+ # Setup logging
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
 
16
# ======================================================
# Load Configuration
# ======================================================
def load_config():
    """Read settings from config.json, falling back to built-in defaults.

    Returns the parsed JSON dict, or a hard-coded default configuration
    (model id, generation knobs, interface limits) when the file is absent.
    """
    try:
        with open("config.json", "r", encoding="utf-8") as cfg_file:
            return json.load(cfg_file)
    except FileNotFoundError:
        logger.warning("config.json not found, using default settings")
        defaults = {
            "model": {"model_id": "unsloth/gpt-oss-20b-GGUF"},
            "generation": {
                "max_new_tokens": 1024,
                "temperature": 1,
                "top_p": 0.95,
                "top_k": 64,
                "do_sample": True,
                "repetition_penalty": 1.1,
                "timeout_seconds": 60,
            },
            "interface": {"max_context_length": 4096},
        }
        return defaults
39
 
40
config = load_config()

# ======================================================
# Settings
# ======================================================
# NOTE(review): this fallback model id differs from the one inside
# load_config()'s defaults — confirm which is intended.
MODEL_ID = config["model"].get("model_id", "anaspro/Lahja-iraqi-4B")

# System prompt lives in an external file so it can be edited without code changes.
try:
    with open("system_prompt.txt", "r", encoding="utf-8") as f:
        SYSTEM_PROMPT = f.read()
except FileNotFoundError:
    logger.warning("system_prompt.txt not found, using default prompt")
    SYSTEM_PROMPT = "أنت مساعد ذكي مفيد. تحدث بالعربية وساعد المستخدم في استفساراته."

# Authenticate with the Hub when a token is provided via the environment.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    logger.info("🔐 Logged in to Hugging Face")

# Global model state, populated by load_model(); model_lock is a simple
# re-entrancy guard for the loader (not a real thread lock).
model = None
tokenizer = None
model_lock = False
64
+
65
# ======================================================
# Model loading function
# ======================================================
def load_model():
    """Populate the module-level ``model`` and ``tokenizer`` globals.

    Returns True on success, False when loading fails or another load is
    already in progress (guarded by the ``model_lock`` flag).
    """
    global model, tokenizer, model_lock

    if model_lock:
        logger.info("Model loading already in progress...")
        return False

    model_lock = True
    try:
        logger.info("🔄 Loading model...")

        # Tokenizer first — it is cheap and needed to validate the repo.
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            use_fast=True,
        )
        if tokenizer.pad_token is None:
            # Reuse EOS as padding so generation kwargs always have a pad id.
            tokenizer.pad_token = tokenizer.eos_token

        model_cfg = config["model"]

        # Optional 4-bit NF4 quantization, controlled by config.
        quant_cfg = None
        if model_cfg.get("load_in_4bit", False):
            quant_cfg = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=model_cfg.get("torch_dtype", "auto"),
            device_map=model_cfg.get("device_map", "auto"),
            trust_remote_code=model_cfg.get("trust_remote_code", True),
            low_cpu_mem_usage=model_cfg.get("low_cpu_mem_usage", True),
            quantization_config=quant_cfg,
        )
        model.eval()

        # Free any allocator headroom left over from loading.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info("✅ Model loaded successfully!")
        return True

    except Exception as e:
        logger.error(f"❌ Error loading model: {str(e)}")
        return False
    finally:
        model_lock = False
126
+
127
# ======================================================
# Chat function (ZeroGPU)
# ======================================================
@spaces.GPU(duration=120)
def chat(message, history):
    """Streaming chat handler for gr.ChatInterface.

    Yields progressively longer partial responses as tokens stream in.

    BUGFIX: this function contains ``yield`` and is therefore a generator —
    a bare ``return "text"`` only sets the StopIteration value and the text
    never reaches the UI. Every user-facing message below is *yielded*
    before returning.
    """
    global model, tokenizer

    # Model may still be loading (or have failed to load) at startup.
    if model is None or tokenizer is None:
        yield "❌ عذراً، النموذج لم يتم تحميله بعد. يرجى الانتظار قليلاً والمحاولة مرة أخرى."
        return

    try:
        # ------------------------------------------------------------------
        # Build the conversation: system prompt, prior turns, current message
        # ------------------------------------------------------------------
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]

        for exchange in history or []:
            if isinstance(exchange, dict):
                # Gradio "messages" format: {"role": ..., "content": ...}
                if exchange.get("role") == "user":
                    messages.append({"role": "user", "content": exchange.get("content", "")})
                elif exchange.get("role") == "assistant":
                    messages.append({"role": "assistant", "content": exchange.get("content", "")})
            elif isinstance(exchange, (list, tuple)) and len(exchange) >= 2:
                # Legacy [user_msg, assistant_msg] tuple format.
                if exchange[0]:
                    messages.append({"role": "user", "content": str(exchange[0])})
                if exchange[1]:
                    messages.append({"role": "assistant", "content": str(exchange[1])})

        if not (message and message.strip()):
            yield "يرجى كتابة رسالة صحيحة."
            return

        # Topic filter: only internet/ISP-support questions are forwarded to
        # the model; short messages are allowed as likely follow-ups.
        internet_keywords = ["نت", "انترنت", "مودم", "wifi", "باقة", "سرعة", "كابل", "راوتر", "فايبر", "اتصال", "شبكة", "تحميل", "رفع", "ميجا", "جيجا"]
        message_lower = message.lower()
        has_internet_keywords = any(keyword in message_lower for keyword in internet_keywords)
        is_short_question = len(message.strip()) < 50  # short questions allowed

        if not (has_internet_keywords or is_short_question):
            yield "آسف، انا هنا حتى اساعدك بمشاكل النت والباقات بس. شنو مشكلتك بالإنترنت؟"
            return
        messages.append({"role": "user", "content": message.strip()})

        # ------------------------------------------------------------------
        # Tokenize input (truncated to the configured context window)
        # ------------------------------------------------------------------
        try:
            max_length = config.get("interface", {}).get("max_context_length", 4096)
            input_ids = tokenizer.apply_chat_template(
                messages,
                return_tensors="pt",
                add_generation_prompt=True,
                truncation=True,
                max_length=max_length,
            ).to(model.device)
        except Exception as e:
            logger.error(f"Tokenization error: {e}")
            yield "❌ خطأ في معالجة الرسالة. يرجى المحاولة مرة أخرى."
            return

        # ------------------------------------------------------------------
        # Set up streaming generation in a background thread
        # ------------------------------------------------------------------
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        generation_config = config.get("generation", {})
        # NOTE: `early_stopping` was dropped — it only applies to beam search
        # and merely triggered a warning under sampling.
        generation_kwargs = {
            "input_ids": input_ids,
            "streamer": streamer,
            "max_new_tokens": generation_config.get("max_new_tokens", 800),
            "min_new_tokens": 15,
            "temperature": generation_config.get("temperature", 0.6),
            "top_p": generation_config.get("top_p", 0.85),
            "top_k": generation_config.get("top_k", 30),
            "do_sample": generation_config.get("do_sample", True),
            "repetition_penalty": generation_config.get("repetition_penalty", 1.15),
            "no_repeat_ngram_size": 4,  # discourage repeating longer phrases
            "pad_token_id": tokenizer.pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "use_cache": True,
        }

        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.daemon = True
        thread.start()

        partial_text = ""
        start_time = time.time()
        timeout = generation_config.get("timeout_seconds", 60)

        # Markers that indicate the model started hallucinating a new dialogue.
        dialogue_indicators = ["👤", "🤖", "العميل:", "الزبون:", "المساعد:", "العضو:", "السؤال:", "الجواب:"]

        try:
            for new_text in streamer:
                if time.time() - start_time > timeout:
                    logger.warning("Generation timeout reached")
                    break

                partial_text += new_text

                # Stop and truncate if the model begins writing a dialogue
                # (the first 50 chars are ignored to avoid false positives).
                for indicator in dialogue_indicators:
                    if indicator in partial_text[50:]:
                        logger.info("Stopping generation - dialogue detected")
                        yield partial_text[:partial_text.find(indicator, 50)].strip()
                        return

                yield partial_text
        except Exception as e:
            logger.error(f"Generation error: {e}")
            yield "❌ حدث خطأ أثناء توليد الإجابة. يرجى المحاولة مرة أخرى."

        thread.join(timeout=5)  # give the worker thread a moment to finish

        # Release cached GPU memory after each request.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    except Exception as e:
        logger.error(f"Chat function error: {e}")
        yield f"❌ حدث خطأ غير متوقع: {str(e)}"
264
+
265
+
266
# ======================================================
# Gradio Interface with enhanced styling
# ======================================================
def create_interface():
    """Build and return the Gradio ChatInterface wired to ``chat``."""

    # Custom CSS for better styling
    custom_css = """
    .gradio-container {
        max-width: 1000px !important;
        margin: auto !important;
    }
    .chat-message {
        padding: 10px !important;
        margin: 5px 0 !important;
        border-radius: 10px !important;
    }
    .message {
        font-size: 16px !important;
        line-height: 1.5 !important;
    }
    """

    # Sample customer questions shown under the chat box.
    sample_questions = [
        ["النت عندي بطيء جداً رغم باقة 100 ميجا. شرحلي الأسباب المحتملة والحلول."],
        ["أريد فهم ليش النت بطيء. شرحلي خطوة بخطوة الأسباب والحلول."],
        ["كم سعر باقة 60 ميجا وما هي مزاياها؟"],
        ["جهازي يظهر متصل بس المواقع ما تفتح. ساعدني أشخيص المشكلة."],
        ["أنا صاحب مؤسسة، أي باقة تناسب 10 موظفين وكم التكلفة؟"],
        ["شلون اغير كلمة مرور الواي فاي خطوة بخطوة؟"],
        ["النت ينقطع فجأة ويعود. ما السبب وكيف أصلحه؟"],
    ]

    soft_theme = gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="gray",
        neutral_hue="slate",
    )

    # Simple ChatInterface for best compatibility with streaming generators.
    demo = gr.ChatInterface(
        fn=chat,
        type="messages",
        title="📞 دعم فني - NB TEL مساعد عراقي",
        description="**مساعد ذكي متقدم يعتمد على GPT-OSS-20B من OpenAI للدعم الفني بشبكة النور - NB TEL**\n\n✨ قدرات متقدمة: تفكير منطقي، حلول خطوة بخطوة، تحليل شامل\n\nاحجي معاه كأنك زبون: اشرح مشكلتك، اسأل عن الباقات، او اطلب تذكرة دعم.",
        examples=sample_questions,
        cache_examples=False,
        theme=soft_theme,
        css=custom_css,
    )

    return demo
314
+
315
# ======================================================
# Load model on startup (before creating interface)
# ======================================================
logger.info("🚀 Starting application - loading model...")
# Failures are logged inside load_model(); chat() then shows a
# "model not loaded" message to users instead of crashing the app.
load_model()

# Create the interface
demo = create_interface()

if __name__ == "__main__":
    demo.launch()