KGSAGAR committed on
Commit
896a3c2
·
verified ·
1 Parent(s): 6cb1f28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -24
app.py CHANGED
@@ -1,30 +1,35 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  from peft import PeftModel
 
4
  import re
5
  import torch
6
 
 
 
 
 
7
  # Model and tokenizer loading (outside the respond function)
8
  try:
9
- # Load the tokenizer
10
  tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
11
-
12
- # Load the base model
13
  base_model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1")
14
-
15
- # Load the PEFT model
16
  peft_model = PeftModel.from_pretrained(base_model, "KGSAGAR/Sarvam-1-text-normalization-3r")
17
-
18
- # Merge and unload the PEFT model into the base model
19
  peft_model = peft_model.merge_and_unload()
20
-
21
- print("Model loaded successfully!")
22
  except Exception as e:
23
  print(f"Error loading model: {e}")
24
  tokenizer = None
 
25
  peft_model = None
26
 
27
- def respond(message, history, system_message, max_tokens, temperature, top_p):
 
 
 
 
 
 
 
28
  """
29
  Generates a response based on the user message and history using the provided PEFT model.
30
  Args:
@@ -34,12 +39,13 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
34
  max_tokens (int): The maximum number of tokens to generate.
35
  temperature (float): The temperature parameter for generation.
36
  top_p (float): The top_p parameter for nucleus sampling.
37
- Returns:
38
- str: The generated response.
39
  """
40
- global tokenizer, peft_model # Access global variables
41
  if tokenizer is None or peft_model is None:
42
- return "Model loading failed. Please check the logs."
 
43
 
44
  # Construct the prompt
45
  prompt = system_message
@@ -63,7 +69,9 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
63
  do_sample=True # Enable sampling for more diverse outputs
64
  )
65
  except Exception as e:
66
- return f"Generation error: {e}"
 
 
67
 
68
  # Decode the generated tokens
69
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -74,29 +82,41 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
74
  Extracts and returns content between <user>...</user> tags in the given text.
75
  If multiple such sections exist, their contents are concatenated.
76
  """
77
- pattern = re.compile(r'<user>(.*?)</user>|output:', re.IGNORECASE)
78
  matches = re.findall(pattern, text, re.DOTALL)
79
- extracted_content = '\n'.join(match.strip() for match in matches if match)
80
  return extracted_content
81
 
82
  # Extract the normalized text
83
  normalized_text = extract_user_content(generated_text)
84
 
85
- return normalized_text
 
 
 
 
 
 
86
 
87
- # Gradio interface setup
 
 
88
  demo = gr.ChatInterface(
89
  respond,
90
  additional_inputs=[
91
- gr.Textbox(
92
- value="Take the user input in Hindi language and normalize specific entities, including: Dates (any format), Currencies, Scientific units. Example input: 2012–13 में रक्षा सेवाओं के लिए 1,93,407 करोड़ रुपए का प्रावधान किया गया था, जबकि 2011–2012 में यह राशि 1,64,415 करोड़ थी. Example output: ट्वेन्टी ट्वेल्व थर्टीन में रक्षा सेवाओं के लिए एक लाख तिरानवे हजार चार सौ सात करोड़ रुपए का प्रावधान किया गया था, जबकि ट्वेन्टी इलेवन ट्वेल्व में यह राशि एक लाख चौसठ हजार चार सौ पंद्रह करोड़ थी. Only provide the normalized output with utmost accuracy.",
93
- label="System message"
94
- ),
95
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
96
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
97
- gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
 
 
 
 
 
 
98
  ],
99
  )
100
 
 
101
  if __name__ == "__main__":
102
  demo.launch()
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  from peft import PeftModel
4
+ from huggingface_hub import InferenceClient
5
  import re
6
  import torch
7
 
8
+ """
9
+ For more information on huggingface_hub Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
10
+ """
11
+
12
  # Model and tokenizer loading (outside the respond function)
13
  try:
 
14
  tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
 
 
15
  base_model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1")
 
 
16
  peft_model = PeftModel.from_pretrained(base_model, "KGSAGAR/Sarvam-1-text-normalization-3r")
 
 
17
  peft_model = peft_model.merge_and_unload()
18
+ print("Model loaded successfully!") # Add this line
 
19
  except Exception as e:
20
  print(f"Error loading model: {e}")
21
  tokenizer = None
22
+ base_model = None
23
  peft_model = None
24
 
25
+ def respond(
26
+ message,
27
+ history,
28
+ system_message,
29
+ max_tokens,
30
+ temperature,
31
+ top_p,
32
+ ):
33
  """
34
  Generates a response based on the user message and history using the provided PEFT model.
35
  Args:
 
39
  max_tokens (int): The maximum number of tokens to generate.
40
  temperature (float): The temperature parameter for generation.
41
  top_p (float): The top_p parameter for nucleus sampling.
42
+ Yields:
43
+ str: The generated response up to the current token.
44
  """
45
+ global tokenizer, peft_model #access global variables
46
  if tokenizer is None or peft_model is None:
47
+ yield "Model loading failed. Please check the logs."
48
+ return
49
 
50
  # Construct the prompt
51
  prompt = system_message
 
69
  do_sample=True # Enable sampling for more diverse outputs
70
  )
71
  except Exception as e:
72
+ yield f"Generation error: {e}"
73
+ return
74
+
75
 
76
  # Decode the generated tokens
77
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
82
  Extracts and returns content between <user>...</user> tags in the given text.
83
  If multiple such sections exist, their contents are concatenated.
84
  """
85
+ pattern = r'<user>(.*?)</user>'
86
  matches = re.findall(pattern, text, re.DOTALL)
87
+ extracted_content = '\n'.join(match.strip() for match in matches)
88
  return extracted_content
89
 
90
  # Extract the normalized text
91
  normalized_text = extract_user_content(generated_text)
92
 
93
+ # Stream the response token by token
94
+ response = ""
95
+ for token in normalized_text.split():
96
+ response += token + " "
97
+ yield response.strip()
98
+
99
+
100
 
101
+ """
102
+ For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
103
+ """
104
  demo = gr.ChatInterface(
105
  respond,
106
  additional_inputs=[
107
+ gr.Textbox(value="Take the user input in Hindi language and normalize specific entities, Only including: Dates (any format) Currencies Scientific units, <Example> Exampleinput : 2012–13 मे�� रक्षा सेवाओं के लिए 1,93,407 करोड़ रुपए का प्रावधान किया गया था, जबकि 2011–2012 में यह राशि 1,64,415 करोइ़ थी, Exampleoutput: ट्वेन्टी ट्वेल्व फिफ्टीन में रक्षा सेवाओं के लिए वन करोड़ निनेटी थ्री थाउजेंड फोर हंड्रेड सेवन करोड़ रुपए का प्रावधान किया गया था, जबकि ट्वेन्टी एलेवन ट्वेल्व में यह राशि वन करोड़ सिक्स्टी फोर थाउजेंड फोर हंड्रेड फिफ्टीन करोड़ थी </Example>, Only provide the normalized output with atmost accuracy <user> input:", label="System message"),
 
 
 
108
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
109
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
110
+ gr.Slider(
111
+ minimum=0.1,
112
+ maximum=1.0,
113
+ value=0.95,
114
+ step=0.05,
115
+ label="Top-p (nucleus sampling)",
116
+ ),
117
  ],
118
  )
119
 
120
+
121
  if __name__ == "__main__":
122
  demo.launch()