anaspro committed on
Commit 8a6b1b9 · 1 Parent(s): 6da46a3
Files changed (1)
  1. app.py +24 -89
app.py CHANGED
@@ -15,7 +15,7 @@ def load_system_prompt():
 
 DEFAULT_SYSTEM_PROMPT = load_system_prompt()
 
-model_path = "anaspro/Lahja-iraqi-4B"
+model_path = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
 
 # If there is an HF_TOKEN in the environment
 hf_token = os.getenv("HF_TOKEN")
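
The new model_path points at a pre-quantized bnb-4bit checkpoint, and the HF_TOKEN read just above suggests the repo may be gated. The loading code itself sits outside this hunk; a minimal sketch of how such a checkpoint is typically loaded, assuming transformers with bitsandbytes and accelerate installed and that model/tokenizer are the globals referenced later in the diff:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_path, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    token=hf_token,     # pass the token in case the repo is gated/private
    device_map="auto",  # the 4-bit quantization config ships inside the bnb-4bit repo
)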
@@ -96,33 +96,6 @@ def format_conversation_history(chat_history):
 
 @spaces.GPU()
 def generate_response(input_data, chat_history, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
-    # Quick sanity test first
-    try:
-        # A simple test message
-        test_prompt = "السلام عليكم"
-        inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)
-
-        print(f"Input shape: {inputs.input_ids.shape}")  # Debug
-        print(f"Input tokens: {inputs.input_ids[0][:10]}")  # Debug
-
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=50,  # short, for testing
-                do_sample=False,
-                num_beams=1,
-                pad_token_id=tokenizer.eos_token_id,
-                eos_token_id=tokenizer.eos_token_id,
-            )
-
-        test_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        print(f"Test response: {test_response}")  # Debug
-
-    except Exception as e:
-        print(f"Test failed: {e}")
-        import traceback
-        print(traceback.format_exc())
-
     # Build messages for Llama chat template
     messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]
 
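
For orientation: format_conversation_history, named in the hunk header above, is not itself touched by this commit, so its body does not appear here. A purely illustrative sketch of what such a helper typically looks like (the tuple-style history format is an assumption; Gradio also supports an OpenAI-style messages format):

def format_conversation_history(chat_history):
    # Hypothetical shape, not the app's actual code: map Gradio
    # (user, assistant) pairs onto chat-template role dicts.
    messages = []
    for user_msg, assistant_msg in chat_history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages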
@@ -135,74 +108,36 @@ def generate_response(input_data, chat_history, max_new_tokens, temperature, top
     # Add current user message
     messages.append({"role": "user", "content": input_data})
 
-    # Use generate directly with safer parameters
-    try:
-        # Try the chat template first
-        if hasattr(tokenizer, 'apply_chat_template') and tokenizer.chat_template is not None:
-            prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-            print(f"Using chat template. Prompt length: {len(prompt)}")  # Debug
-        else:
-            # Fallback format
-            prompt = f"System: {DEFAULT_SYSTEM_PROMPT}\n\n"
-            for msg in messages[1:]:  # Skip system message since we added it above
-                if msg["role"] == "user":
-                    prompt += f"Human: {msg['content']}\n"
-                elif msg["role"] == "assistant":
-                    prompt += f"Assistant: {msg['content']}\n"
-            prompt += "Assistant:"
-            print(f"Using fallback format. Prompt length: {len(prompt)}")  # Debug
-
-        print(f"Final prompt: {prompt[:200]}...")  # Debug first 200 chars
-
-        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-        print(f"Tokenized input shape: {inputs.input_ids.shape}")  # Debug
-
-        # Call generate with basic, safe parameters
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=min(max_new_tokens, 512),  # safety cap
-                do_sample=False,  # disable sampling for safety
-                num_beams=1,  # greedy decoding
-                pad_token_id=tokenizer.eos_token_id,
-                eos_token_id=tokenizer.eos_token_id,
-                return_dict_in_generate=True,
-                output_scores=False,
-            )
-
-        print(f"Generated sequence shape: {outputs.sequences.shape}")  # Debug
-        print(f"Input length: {inputs.input_ids.shape[1]}")  # Debug
-
-        response = tokenizer.decode(outputs.sequences[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
-        response = response.strip()
-
-        print(f"Generated response length: {len(response)}")  # Debug
-        print(f"Response preview: {response[:100]}...")  # Debug
-
-        if not response:
-            print("Empty response, using fallback")  # Debug
-            response = "أهلاً! أنا أليكس مساعد خدمة العملاء. كيف أقدر أساعدك اليوم؟"
-
+    # Use the custom ChatPipeline with streaming
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    # pipe() is assumed to be the app's ChatPipeline helper: it applies the
+    # chat template and returns the kwargs dict handed to model.generate().
+    generation_kwargs = pipe(
+        messages,
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        repetition_penalty=repetition_penalty
+    )
+
+    # generate() blocks, so it runs on a worker thread while the streamer
+    # is drained below.
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    # Stream the response
+    response = ""
+    for chunk in streamer:
+        response += chunk
         yield response
 
-    except Exception as e:
-        error_msg = f"خطأ في التوليد: {str(e)}"
-        print(error_msg)
-        print(f"Error type: {type(e)}")  # Debug
-        import traceback
-        print("Traceback:")
-        print(traceback.format_exc())  # Debug
-
-        yield "أهلاً! أنا أليكس مساعد خدمة العملاء. كيف أقدر أساعدك اليوم؟"
-
 demo = gr.ChatInterface(
     fn=generate_response,
     additional_inputs=[
         gr.Slider(label="الحد الأقصى للكلمات الجديدة", minimum=64, maximum=4096, step=1, value=2048),
-        gr.Slider(label="درجة الحرارة", minimum=0.1, maximum=2.0, step=0.1, value=1.0),
-        gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.9),
+        gr.Slider(label="درجة الحرارة", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
+        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
         gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
-        gr.Slider(label="عقوبة التكرار", minimum=1.0, maximum=1.5, step=0.05, value=1.2)
+        gr.Slider(label="عقوبة التكرار", minimum=1.0, maximum=2.0, step=0.05, value=1.0)
     ],
     examples=[
         [{"text": "النت عندي معطل من الصبح، تقدر تساعدني؟"}],
 