Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,30 +1,35 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 3 |
from peft import PeftModel
|
|
|
|
| 4 |
import re
|
| 5 |
import torch
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
# Model and tokenizer loading (outside the respond function)
|
| 8 |
try:
|
| 9 |
-
# Load the tokenizer
|
| 10 |
tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
|
| 11 |
-
|
| 12 |
-
# Load the base model
|
| 13 |
base_model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1")
|
| 14 |
-
|
| 15 |
-
# Load the PEFT model
|
| 16 |
peft_model = PeftModel.from_pretrained(base_model, "KGSAGAR/Sarvam-1-text-normalization-3r")
|
| 17 |
-
|
| 18 |
-
# Merge and unload the PEFT model into the base model
|
| 19 |
peft_model = peft_model.merge_and_unload()
|
| 20 |
-
|
| 21 |
-
print("Model loaded successfully!")
|
| 22 |
except Exception as e:
|
| 23 |
print(f"Error loading model: {e}")
|
| 24 |
tokenizer = None
|
|
|
|
| 25 |
peft_model = None
|
| 26 |
|
| 27 |
-
def respond(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
"""
|
| 29 |
Generates a response based on the user message and history using the provided PEFT model.
|
| 30 |
Args:
|
|
@@ -34,12 +39,13 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
|
|
| 34 |
max_tokens (int): The maximum number of tokens to generate.
|
| 35 |
temperature (float): The temperature parameter for generation.
|
| 36 |
top_p (float): The top_p parameter for nucleus sampling.
|
| 37 |
-
|
| 38 |
-
str: The generated response.
|
| 39 |
"""
|
| 40 |
-
global tokenizer, peft_model
|
| 41 |
if tokenizer is None or peft_model is None:
|
| 42 |
-
|
|
|
|
| 43 |
|
| 44 |
# Construct the prompt
|
| 45 |
prompt = system_message
|
|
@@ -63,7 +69,9 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
|
|
| 63 |
do_sample=True # Enable sampling for more diverse outputs
|
| 64 |
)
|
| 65 |
except Exception as e:
|
| 66 |
-
|
|
|
|
|
|
|
| 67 |
|
| 68 |
# Decode the generated tokens
|
| 69 |
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
@@ -74,29 +82,41 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
|
|
| 74 |
Extracts and returns content between <user>...</user> tags in the given text.
|
| 75 |
If multiple such sections exist, their contents are concatenated.
|
| 76 |
"""
|
| 77 |
-
pattern =
|
| 78 |
matches = re.findall(pattern, text, re.DOTALL)
|
| 79 |
-
extracted_content = '\n'.join(match.strip() for match in matches
|
| 80 |
return extracted_content
|
| 81 |
|
| 82 |
# Extract the normalized text
|
| 83 |
normalized_text = extract_user_content(generated_text)
|
| 84 |
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
-
|
|
|
|
|
|
|
| 88 |
demo = gr.ChatInterface(
|
| 89 |
respond,
|
| 90 |
additional_inputs=[
|
| 91 |
-
gr.Textbox(
|
| 92 |
-
value="Take the user input in Hindi language and normalize specific entities, including: Dates (any format), Currencies, Scientific units. Example input: 2012–13 में रक्षा सेवाओं के लिए 1,93,407 करोड़ रुपए का प्रावधान किया गया था, जबकि 2011–2012 में यह राशि 1,64,415 करोड़ थी. Example output: ट्वेन्टी ट्वेल्व थर्टीन में रक्षा सेवाओं के लिए एक लाख तिरानवे हजार चार सौ सात करोड़ रुपए का प्रावधान किया गया था, जबकि ट्वेन्टी इलेवन ट्वेल्व में यह राशि एक लाख चौसठ हजार चार सौ पंद्रह करोड़ थी. Only provide the normalized output with utmost accuracy.",
|
| 93 |
-
label="System message"
|
| 94 |
-
),
|
| 95 |
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
| 96 |
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
| 97 |
-
gr.Slider(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
],
|
| 99 |
)
|
| 100 |
|
|
|
|
| 101 |
if __name__ == "__main__":
|
| 102 |
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 3 |
from peft import PeftModel
|
| 4 |
+
from huggingface_hub import InferenceClient
|
| 5 |
import re
|
| 6 |
import torch
|
| 7 |
|
| 8 |
+
"""
|
| 9 |
+
For more information on huggingface_hub Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
# Load the model and tokenizer once at module import time, so the expensive
# download/merge work happens outside the per-request respond() function.
try:
    tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
    base_model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1")
    # Attach the LoRA adapter, then fold its weights into the base model so
    # inference runs against a single merged model.
    peft_model = PeftModel.from_pretrained(
        base_model, "KGSAGAR/Sarvam-1-text-normalization-3r"
    )
    peft_model = peft_model.merge_and_unload()
    print("Model loaded successfully!")
except Exception as e:
    # Best-effort startup: record the failure and null out the globals so
    # the request handler can detect the broken state and report it.
    print(f"Error loading model: {e}")
    tokenizer = None
    base_model = None
    peft_model = None
|
| 24 |
|
| 25 |
+
def respond(
|
| 26 |
+
message,
|
| 27 |
+
history,
|
| 28 |
+
system_message,
|
| 29 |
+
max_tokens,
|
| 30 |
+
temperature,
|
| 31 |
+
top_p,
|
| 32 |
+
):
|
| 33 |
"""
|
| 34 |
Generates a response based on the user message and history using the provided PEFT model.
|
| 35 |
Args:
|
|
|
|
| 39 |
max_tokens (int): The maximum number of tokens to generate.
|
| 40 |
temperature (float): The temperature parameter for generation.
|
| 41 |
top_p (float): The top_p parameter for nucleus sampling.
|
| 42 |
+
Yields:
|
| 43 |
+
str: The generated response up to the current token.
|
| 44 |
"""
|
| 45 |
+
global tokenizer, peft_model #access global variables
|
| 46 |
if tokenizer is None or peft_model is None:
|
| 47 |
+
yield "Model loading failed. Please check the logs."
|
| 48 |
+
return
|
| 49 |
|
| 50 |
# Construct the prompt
|
| 51 |
prompt = system_message
|
|
|
|
| 69 |
do_sample=True # Enable sampling for more diverse outputs
|
| 70 |
)
|
| 71 |
except Exception as e:
|
| 72 |
+
yield f"Generation error: {e}"
|
| 73 |
+
return
|
| 74 |
+
|
| 75 |
|
| 76 |
# Decode the generated tokens
|
| 77 |
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
| 82 |
Extracts and returns content between <user>...</user> tags in the given text.
|
| 83 |
If multiple such sections exist, their contents are concatenated.
|
| 84 |
"""
|
| 85 |
+
pattern = r'<user>(.*?)</user>'
|
| 86 |
matches = re.findall(pattern, text, re.DOTALL)
|
| 87 |
+
extracted_content = '\n'.join(match.strip() for match in matches)
|
| 88 |
return extracted_content
|
| 89 |
|
| 90 |
# Extract the normalized text
|
| 91 |
normalized_text = extract_user_content(generated_text)
|
| 92 |
|
| 93 |
+
# Stream the response token by token
|
| 94 |
+
response = ""
|
| 95 |
+
for token in normalized_text.split():
|
| 96 |
+
response += token + " "
|
| 97 |
+
yield response.strip()
|
| 98 |
+
|
| 99 |
+
|
| 100 |
|
| 101 |
+
"""
|
| 102 |
+
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
|
| 103 |
+
"""
|
| 104 |
# Gradio chat UI: the system message carries the normalization instructions
# (with a worked Hindi example); the sliders expose the sampling controls
# consumed by respond(message, history, system_message, max_tokens,
# temperature, top_p).
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            # Fixed in this prompt string: mojibake "मे??" -> "में",
            # misspelling "करोइ़" -> "करोड़", and typo "atmost" -> "utmost".
            value="Take the user input in Hindi language and normalize specific entities, Only including: Dates (any format) Currencies Scientific units, <Example> Exampleinput : 2012–13 में रक्षा सेवाओं के लिए 1,93,407 करोड़ रुपए का प्रावधान किया गया था, जबकि 2011–2012 में यह राशि 1,64,415 करोड़ थी, Exampleoutput: ट्वेन्टी ट्वेल्व फिफ्टीन में रक्षा सेवाओं के लिए वन करोड़ निनेटी थ्री थाउजेंड फोर हंड्रेड सेवन करोड़ रुपए का प्रावधान किया गया था, जबकि ट्वेन्टी एलेवन ट्वेल्व में यह राशि वन करोड़ सिक्स्टी फोर थाउजेंड फोर हंड्रेड फिफ्टीन करोड़ थी </Example>, Only provide the normalized output with utmost accuracy <user> input:",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
|
| 119 |
|
| 120 |
+
|
| 121 |
# Script entry point: launch the Gradio app only when run directly,
# not when this module is imported.
if __name__ == "__main__":
    demo.launch()
|