YOUSEF2434 committed on
Commit ac40f97 · verified · 1 Parent(s): 45d1dc3

Update app.py

Files changed (1)
  1. app.py +131 -119
app.py CHANGED
@@ -1,119 +1,131 @@
-#!/usr/bin/env python
-
-import os
-from collections.abc import Iterator
-from threading import Thread
-
-import gradio as gr
-import spaces
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-
-DESCRIPTION = "# Mistral-7B v0.3"
-
-if not torch.cuda.is_available():
-    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
-
-MAX_MAX_NEW_TOKENS = 2048
-DEFAULT_MAX_NEW_TOKENS = 1024
-MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
-if torch.cuda.is_available():
-    model_id = "mistralai/Mistral-7B-Instruct-v0.3"
-    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-
-@spaces.GPU
-def generate(
-    message: str,
-    chat_history: list[dict],
-    max_new_tokens: int = 1024,
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
-) -> Iterator[str]:
-    conversation = [*chat_history, {"role": "user", "content": message}]
-
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
-
-
-demo = gr.ChatInterface(
-    fn=generate,
-    additional_inputs=[
-        gr.Slider(
-            label="Max new tokens",
-            minimum=1,
-            maximum=MAX_MAX_NEW_TOKENS,
-            step=1,
-            value=DEFAULT_MAX_NEW_TOKENS,
-        ),
-        gr.Slider(
-            label="Temperature",
-            minimum=0.1,
-            maximum=4.0,
-            step=0.1,
-            value=0.6,
-        ),
-        gr.Slider(
-            label="Top-p (nucleus sampling)",
-            minimum=0.05,
-            maximum=1.0,
-            step=0.05,
-            value=0.9,
-        ),
-        gr.Slider(
-            label="Top-k",
-            minimum=1,
-            maximum=1000,
-            step=1,
-            value=50,
-        ),
-        gr.Slider(
-            label="Repetition penalty",
-            minimum=1.0,
-            maximum=2.0,
-            step=0.05,
-            value=1.2,
-        ),
-    ],
-    stop_btn=None,
-    examples=[
-        ["Hello there! How are you doing?"],
-        ["Can you explain briefly to me what is the Python programming language?"],
-        ["Explain the plot of Cinderella in a sentence."],
-        ["How many hours does it take a man to eat a Helicopter?"],
-        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
-    ],
-    type="messages",
-    description=DESCRIPTION,
-    css_paths="style.css",
-)
-
-if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+import os
+from collections.abc import Iterator
+from threading import Thread
+
+import gradio as gr
+import spaces
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+DESCRIPTION = "# Mistral-7B v0.3"
+if not torch.cuda.is_available():
+    DESCRIPTION += "\n<p><strong>Note:</strong> Running on CPU. This will be slower than GPU.</p>"
+
+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+# Load model and tokenizer (CPU or GPU)
+model_id = "mistralai/Mistral-7B-Instruct-v0.3"
+
+if torch.cuda.is_available():
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
+else:
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
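+    # Note: a 7B model in float32 occupies roughly 28 GB of RAM; on CPUs
+    # that support it, torch.bfloat16 roughly halves that footprint.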
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+
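+# Note: @spaces.GPU requests ZeroGPU hardware when this app runs on Hugging
+# Face Spaces; outside Spaces the decorator is expected to act as a no-op,
+# so the CPU fallback above still applies.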
+@spaces.GPU
+def generate(
+    message: str,
+    chat_history: list[dict],
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+) -> Iterator[str]:
+    # Inject Sheikh personality
+    system_prompt = {
+        "role": "system",
+        "content": (
+            "You are a wise, respectful, and knowledgeable Islamic AI Sheikh. "
+            "You answer all questions based on the Qur’an, authentic Hadith, and the views of classical scholars. "
+            "You follow the four Sunni madhhabs (Hanafi, Maliki, Shafi'i, Hanbali) and avoid personal opinions. "
+            "Speak gently and cite Islamic sources when appropriate. Avoid answering off-topic or non-Islamic questions."
+        )
+    }
+    conversation = [system_prompt] + chat_history + [{"role": "user", "content": message}]
+
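+    # Note: some Mistral-Instruct chat templates reject a "system" role and
+    # expect strictly alternating user/assistant turns; if apply_chat_template
+    # raises on the system message, folding its text into the first user
+    # message is a common workaround.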
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
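+    # Note: timeout=20.0 bounds how long the streamer waits for each token;
+    # on CPU the first token can take longer than that, so a larger value
+    # (or timeout=None) may be needed for the CPU path.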
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
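+    # Generation runs in the background thread; iterating the streamer
+    # below yields the partial text so Gradio can render it incrementally.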
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
+
+
+demo = gr.ChatInterface(
+    fn=generate,
+    additional_inputs=[
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        ),
+        gr.Slider(
+            label="Temperature",
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=0.6,
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.9,
+        ),
+        gr.Slider(
+            label="Top-k",
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        ),
+        gr.Slider(
+            label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.2,
+        ),
+    ],
+    stop_btn=None,
+    examples=[
+        ["What is the ruling on fasting in Ramadan?"],
+        ["Tell me the story of Prophet Yusuf (peace be upon him)."],
+        ["What is the difference between Zakat and Sadaqah?"],
+        ["Can you explain the meaning of Surah Al-Fatiha?"],
+        ["How do I perform Wudu correctly?"],
+    ],
+    type="messages",
+    description=DESCRIPTION,
+    css_paths="style.css",
+)
+
+if __name__ == "__main__":
+    demo.queue(max_size=20).launch()