KGSAGAR committed on
Commit
c5ecffb
·
verified ·
1 Parent(s): 896a3c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -31
app.py CHANGED
@@ -5,17 +5,13 @@ from huggingface_hub import InferenceClient
5
  import re
6
  import torch
7
 
8
- """
9
- For more information on huggingface_hub Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
10
- """
11
-
12
  # Model and tokenizer loading (outside the respond function)
13
  try:
14
  tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
15
  base_model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1")
16
  peft_model = PeftModel.from_pretrained(base_model, "KGSAGAR/Sarvam-1-text-normalization-3r")
17
  peft_model = peft_model.merge_and_unload()
18
- print("Model loaded successfully!") # Add this line
19
  except Exception as e:
20
  print(f"Error loading model: {e}")
21
  tokenizer = None
@@ -31,7 +27,7 @@ def respond(
31
  top_p,
32
  ):
33
  """
34
- Generates a response based on the user message and history using the provided PEFT model.
35
  Args:
36
  message (str): The user's input message.
37
  history (list of tuples): A list containing tuples of (user_message, assistant_response).
@@ -39,13 +35,12 @@ def respond(
39
  max_tokens (int): The maximum number of tokens to generate.
40
  temperature (float): The temperature parameter for generation.
41
  top_p (float): The top_p parameter for nucleus sampling.
42
- Yields:
43
- str: The generated response up to the current token.
44
  """
45
- global tokenizer, peft_model #access global variables
46
  if tokenizer is None or peft_model is None:
47
- yield "Model loading failed. Please check the logs."
48
- return
49
 
50
  # Construct the prompt
51
  prompt = system_message
@@ -59,19 +54,17 @@ def respond(
59
  # Tokenize the input prompt
60
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
61
 
62
- # Generate the output
63
  try:
64
  outputs = peft_model.generate(
65
  **inputs,
66
  max_new_tokens=max_tokens,
67
  temperature=temperature,
68
  top_p=top_p,
69
- do_sample=True # Enable sampling for more diverse outputs
70
  )
71
  except Exception as e:
72
- yield f"Generation error: {e}"
73
- return
74
-
75
 
76
  # Decode the generated tokens
77
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -87,24 +80,17 @@ def respond(
87
  extracted_content = '\n'.join(match.strip() for match in matches)
88
  return extracted_content
89
 
90
- # Extract the normalized text
91
  normalized_text = extract_user_content(generated_text)
 
92
 
93
- # Stream the response token by token
94
- response = ""
95
- for token in normalized_text.split():
96
- response += token + " "
97
- yield response.strip()
98
-
99
-
100
-
101
- """
102
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
103
- """
104
  demo = gr.ChatInterface(
105
  respond,
106
  additional_inputs=[
107
- gr.Textbox(value="Take the user input in Hindi language and normalize specific entities, Only including: Dates (any format) Currencies Scientific units, <Example> Exampleinput : 2012–13 में रक्षा सेवाओं के लिए 1,93,407 करोड़ रुपए का प्रावधान किया गया था, जबकि 2011–2012 में यह राशि 1,64,415 करोइ़ थी, Exampleoutput: ट्वेन्टी ट्वेल्व फिफ्टीन में रक्षा सेवाओं के लिए वन करोड़ निनेटी थ्री थाउजेंड फोर हंड्रेड सेवन करोड़ रुपए का प्रावधान किया गया था, जबकि ट्वेन्टी एलेवन ट्वेल्व में यह राशि वन करोड़ सिक्स्टी फोर थाउजेंड फोर हंड्रेड फिफ्टीन करोड़ थी </Example>, Only provide the normalized output with atmost accuracy <user> input:", label="System message"),
 
 
 
108
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
109
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
110
  gr.Slider(
@@ -117,6 +103,5 @@ demo = gr.ChatInterface(
117
  ],
118
  )
119
 
120
-
121
  if __name__ == "__main__":
122
- demo.launch()
 
5
  import re
6
  import torch
7
 
 
 
 
 
8
  # Model and tokenizer loading (outside the respond function)
9
  try:
10
  tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
11
  base_model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1")
12
  peft_model = PeftModel.from_pretrained(base_model, "KGSAGAR/Sarvam-1-text-normalization-3r")
13
  peft_model = peft_model.merge_and_unload()
14
+ print("Model loaded successfully!")
15
  except Exception as e:
16
  print(f"Error loading model: {e}")
17
  tokenizer = None
 
27
  top_p,
28
  ):
29
  """
30
+ Generates a complete response based on the user message and history using the provided PEFT model.
31
  Args:
32
  message (str): The user's input message.
33
  history (list of tuples): A list containing tuples of (user_message, assistant_response).
 
35
  max_tokens (int): The maximum number of tokens to generate.
36
  temperature (float): The temperature parameter for generation.
37
  top_p (float): The top_p parameter for nucleus sampling.
38
+ Returns:
39
+ str: The complete generated response.
40
  """
41
+ global tokenizer, peft_model
42
  if tokenizer is None or peft_model is None:
43
+ return "Model loading failed. Please check the logs."
 
44
 
45
  # Construct the prompt
46
  prompt = system_message
 
54
  # Tokenize the input prompt
55
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
56
 
57
+ # Generate the complete output
58
  try:
59
  outputs = peft_model.generate(
60
  **inputs,
61
  max_new_tokens=max_tokens,
62
  temperature=temperature,
63
  top_p=top_p,
64
+ do_sample=True
65
  )
66
  except Exception as e:
67
+ return f"Generation error: {e}"
 
 
68
 
69
  # Decode the generated tokens
70
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
80
  extracted_content = '\n'.join(match.strip() for match in matches)
81
  return extracted_content
82
 
83
+ # Extract and return the complete normalized text
84
  normalized_text = extract_user_content(generated_text)
85
+ return normalized_text.strip()
86
 
 
 
 
 
 
 
 
 
 
 
 
87
  demo = gr.ChatInterface(
88
  respond,
89
  additional_inputs=[
90
+ gr.Textbox(
91
+ value="Take the user input in Hindi language and normalize specific entities, Only including: Dates (any format) Currencies Scientific units, <Example> Exampleinput : 2012–13 में रक्षा सेवाओं के लिए 1,93,407 करोड़ रुपए का प्रावधान किया गया था, जबकि 2011–2012 में यह राशि 1,64,415 करोइ़ थी, Exampleoutput: ट्वेन्टी ट्वेल्व फिफ्टीन में रक्षा सेवाओं के लिए वन करोड़ निनेटी थ्री थाउजेंड फोर हंड्रेड सेवन करोड़ रुपए का प्रावधान किया गया था, जबकि ट्वेन्टी एलेवन ट्वेल्व में यह राशि वन करोड़ सिक्स्टी फोर थाउजेंड फोर हंड्रेड फिफ्टीन करोड़ थी </Example>, Only provide the normalized output with atmost accuracy <user> input:",
92
+ label="System message"
93
+ ),
94
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
95
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
96
  gr.Slider(
 
103
  ],
104
  )
105
 
 
106
  if __name__ == "__main__":
107
+ demo.launch()