KGSAGAR committed on
Commit
3b5918b
·
verified ·
1 Parent(s): 9c2dbe4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -36
app.py CHANGED
@@ -1,38 +1,32 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  from peft import PeftModel
4
- from huggingface_hub import InferenceClient
5
  import re
6
  import torch
7
 
8
- """
9
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
10
- """
11
-
12
  # Model and tokenizer loading (outside the respond function)
13
  try:
 
14
  tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
 
 
15
  base_model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1")
 
 
16
  peft_model = PeftModel.from_pretrained(base_model, "KGSAGAR/Sarvam-1-text-normalization-3r")
 
 
17
  peft_model = peft_model.merge_and_unload()
18
- print("Model loaded successfully!") # Add this line
 
19
  except Exception as e:
20
  print(f"Error loading model: {e}")
21
  tokenizer = None
22
- base_model = None
23
  peft_model = None
24
 
25
- def respond(
26
- message,
27
- history,
28
- system_message,
29
- max_tokens,
30
- temperature,
31
- top_p,
32
- ):
33
  """
34
  Generates a response based on the user message and history using the provided PEFT model.
35
-
36
  Args:
37
  message (str): The user's input message.
38
  history (list of tuples): A list containing tuples of (user_message, assistant_response).
@@ -40,11 +34,10 @@ def respond(
40
  max_tokens (int): The maximum number of tokens to generate.
41
  temperature (float): The temperature parameter for generation.
42
  top_p (float): The top_p parameter for nucleus sampling.
43
-
44
  Yields:
45
  str: The generated response up to the current token.
46
  """
47
- global tokenizer, peft_model #access global variables
48
  if tokenizer is None or peft_model is None:
49
  yield "Model loading failed. Please check the logs."
50
  return
@@ -71,9 +64,8 @@ def respond(
71
  do_sample=True # Enable sampling for more diverse outputs
72
  )
73
  except Exception as e:
74
- yield f"Generation error: {e}"
75
- return
76
-
77
 
78
  # Decode the generated tokens
79
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -84,7 +76,7 @@ def respond(
84
  Extracts and returns content between <user>...</user> tags in the given text.
85
  If multiple such sections exist, their contents are concatenated.
86
  """
87
- pattern = r'<user>(.*?)</user>'
88
  matches = re.findall(pattern, text, re.DOTALL)
89
  extracted_content = '\n'.join(match.strip() for match in matches)
90
  return extracted_content
@@ -98,27 +90,20 @@ def respond(
98
  response += token + " "
99
  yield response.strip()
100
 
101
-
102
-
103
- """
104
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
105
- """
106
  demo = gr.ChatInterface(
107
  respond,
108
  additional_inputs=[
109
- gr.Textbox(value="Take the user input in Hindi language and normalize specific entities, Only including: Dates (any format) Currencies Scientific units, <Example> Exampleinput : 2012–13 में रक्षा सेवाओं के लिए 1,93,407 करोड़ रुपए का प्रावधान किया गया था, जबकि 2011–2012 में यह राशि 1,64,415 करोइ़ थी, Exampleoutput: ट्वेन्टी ट्वेल्व फिफ्टीन में रक्षा सेवाओं के लिए वन करोड़ निनेटी थ्री थाउजेंड फोर हंड्रेड सेवन करोड़ रुपए का प्रावधान किया गया था, जबकि ट्वेन्टी एलेवन ट्वेल्व में यह राशि वन करोड़ सिक्स्टी फोर थाउजेंड फोर हंड्रेड फिफ्टीन करोड़ थी </Example>, Only provide the normalized output with atmost accuracy <user> input:", label="System message"),
 
 
 
110
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
111
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
112
- gr.Slider(
113
- minimum=0.1,
114
- maximum=1.0,
115
- value=0.95,
116
- step=0.05,
117
- label="Top-p (nucleus sampling)",
118
- ),
119
  ],
120
  )
121
 
122
-
123
  if __name__ == "__main__":
124
  demo.launch()
 
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  from peft import PeftModel
 
4
  import re
5
  import torch
6
 
 
 
 
 
7
  # Model and tokenizer loading (outside the respond function)
8
  try:
9
+ # Load the tokenizer
10
  tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
11
+
12
+ # Load the base model
13
  base_model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1")
14
+
15
+ # Load the PEFT model
16
  peft_model = PeftModel.from_pretrained(base_model, "KGSAGAR/Sarvam-1-text-normalization-3r")
17
+
18
+ # Merge and unload the PEFT model into the base model
19
  peft_model = peft_model.merge_and_unload()
20
+
21
+ print("Model loaded successfully!")
22
  except Exception as e:
23
  print(f"Error loading model: {e}")
24
  tokenizer = None
 
25
  peft_model = None
26
 
27
+ def respond(message, history, system_message, max_tokens, temperature, top_p):
 
 
 
 
 
 
 
28
  """
29
  Generates a response based on the user message and history using the provided PEFT model.
 
30
  Args:
31
  message (str): The user's input message.
32
  history (list of tuples): A list containing tuples of (user_message, assistant_response).
 
34
  max_tokens (int): The maximum number of tokens to generate.
35
  temperature (float): The temperature parameter for generation.
36
  top_p (float): The top_p parameter for nucleus sampling.
 
37
  Yields:
38
  str: The generated response up to the current token.
39
  """
40
+ global tokenizer, peft_model # Access global variables
41
  if tokenizer is None or peft_model is None:
42
  yield "Model loading failed. Please check the logs."
43
  return
 
64
  do_sample=True # Enable sampling for more diverse outputs
65
  )
66
  except Exception as e:
67
+ yield f"Generation error: {e}"
68
+ return
 
69
 
70
  # Decode the generated tokens
71
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
76
  Extracts and returns content between <user>...</user> tags in the given text.
77
  If multiple such sections exist, their contents are concatenated.
78
  """
79
+ pattern = re.compile(r'<user>(.*?)</user>|output:', re.IGNORECASE)
80
  matches = re.findall(pattern, text, re.DOTALL)
81
  extracted_content = '\n'.join(match.strip() for match in matches)
82
  return extracted_content
 
90
  response += token + " "
91
  yield response.strip()
92
 
93
+ # Gradio interface setup
 
 
 
 
94
  demo = gr.ChatInterface(
95
  respond,
96
  additional_inputs=[
97
+ gr.Textbox(
98
+ value="Take the user input in Hindi language and normalize specific entities, including: Dates (any format), Currencies, Scientific units. Example input: 2012–13 में रक्षा सेवाओं के लिए 1,93,407 करोड़ रुपए का प्रावधान किया गया था, जबकि 2011–2012 में यह राशि 1,64,415 करोड़ थी. Example output: ट्वेन्टी ट्वेल्व थर्टीन में रक्षा सेवाओं के लिए एक लाख तिरानवे हजार चार सौ सात करोड़ रुपए का प्रावधान किया गया था, जबकि ट्वेन्टी इलेवन ट्वेल्व में यह राशि एक लाख चौसठ हजार चार सौ पंद्रह करोड़ थी. Only provide the normalized output with utmost accuracy.",
99
+ label="System message"
100
+ ),
101
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
102
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
103
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
 
 
 
 
 
 
104
  ],
105
  )
106
 
 
107
  if __name__ == "__main__":
108
  demo.launch()
109
+