anaspro committed on
Commit
ec237e7
·
verified ·
1 Parent(s): ae2df6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -20
app.py CHANGED
@@ -1,8 +1,31 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
3
  import spaces
4
- import os
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  @spaces.GPU
7
  def respond(
8
  message,
@@ -12,36 +35,52 @@ def respond(
12
  temperature,
13
  top_p,
14
  ):
15
- # استخدم Token من Secrets
16
- token = os.environ.get("HF_TOKEN")
17
- client = InferenceClient(model="anaspro/iraqi-kashif-2b", token=token)
18
-
19
  messages = [{"role": "system", "content": system_message}]
20
  messages.extend(history)
21
  messages.append({"role": "user", "content": message})
22
-
23
- response = ""
24
- for msg in client.chat_completion(
25
  messages,
26
- max_tokens=max_tokens,
27
- stream=True,
 
 
 
 
 
 
 
 
 
 
28
  temperature=temperature,
29
  top_p=top_p,
30
- ):
31
- if msg.choices and msg.choices[0].delta.content:
32
- response += msg.choices[0].delta.content
33
- yield response
 
 
 
 
 
 
 
 
34
 
 
35
  chatbot = gr.ChatInterface(
36
- respond,
37
  type="messages",
38
  additional_inputs=[
39
- gr.Textbox(value="أنت مساعد عراقي ذكي.", label="System message"),
40
- gr.Slider(minimum=1, maximum=512, value=100, step=1, label="Max tokens"),
41
  gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
42
- gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
43
  ],
44
  )
45
 
46
  if __name__ == "__main__":
47
- chatbot.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
4
+ from threading import Thread
5
  import spaces
 
6
 
7
+ # ✅ Use GPU if available
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+
10
+ # ✅ Load your model and tokenizer
11
+ MODEL_NAME = "anaspro/iraqi-kashif-2b"
12
+
13
+ @spaces.GPU
14
+ def load_model():
15
+ print("🔄 Loading model and tokenizer...")
16
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
17
+ model = AutoModelForCausalLM.from_pretrained(
18
+ MODEL_NAME,
19
+ torch_dtype=torch.float16,
20
+ device_map="auto",
21
+ )
22
+ model.eval()
23
+ print("✅ Model loaded successfully!")
24
+ return tokenizer, model
25
+
26
+ tokenizer, model = load_model()
27
+
28
+ # ✅ Respond function using streaming
29
  @spaces.GPU
30
  def respond(
31
  message,
 
35
  temperature,
36
  top_p,
37
  ):
38
+ # Build the message list: system message first, then prior history, then the new user message
 
 
 
39
  messages = [{"role": "system", "content": system_message}]
40
  messages.extend(history)
41
  messages.append({"role": "user", "content": message})
42
+
43
+ # Render the message list into a prompt string using the tokenizer's chat template (the model repo is expected to provide chat_template.jinja)
44
+ prompt = tokenizer.apply_chat_template(
45
  messages,
46
+ tokenize=False,
47
+ add_generation_prompt=True,
48
+ )
49
+
50
+ # Prepare streamer for live token generation
51
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
52
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
53
+
54
+ generation_kwargs = dict(
55
+ **inputs,
56
+ streamer=streamer,
57
+ max_new_tokens=max_tokens,
58
  temperature=temperature,
59
  top_p=top_p,
60
+ do_sample=True,
61
+ )
62
+
63
+ # Run generation in background thread
64
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
65
+ thread.start()
66
+
67
+ # Stream tokens as they arrive
68
+ response = ""
69
+ for new_text in streamer:
70
+ response += new_text
71
+ yield response
72
 
73
+ # ✅ Gradio chat UI
74
  chatbot = gr.ChatInterface(
75
+ fn=respond,
76
  type="messages",
77
  additional_inputs=[
78
+ gr.Textbox(value="أنت مساعد ذكي تتحدث باللهجة العراقية.", label="System message"),
79
+ gr.Slider(minimum=32, maximum=512, value=128, step=8, label="Max tokens"),
80
  gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
81
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p"),
82
  ],
83
  )
84
 
85
  if __name__ == "__main__":
86
+ chatbot.launch(server_name="0.0.0.0", server_port=7860)