import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import re
import torch
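
# This Space wraps the KGSAGAR/Sarvam-1-text-normalization-3r LoRA adapter,
# trained on top of sarvamai/sarvam-1, in a chat UI that normalizes dates,
# currencies, and scientific units in Hindi text.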
# Model and tokenizer loading (done once at import time, outside the respond function)
try:
    tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
    base_model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1")
    peft_model = PeftModel.from_pretrained(base_model, "KGSAGAR/Sarvam-1-text-normalization-3r")
    # Merge the LoRA weights into the base model for faster inference
    peft_model = peft_model.merge_and_unload()
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    tokenizer = None
    base_model = None
    peft_model = None
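
# Note: the merged model stays on CPU here. On a GPU Space you could move it
# with `peft_model.to("cuda")` and send the tokenized inputs to the same
# device before calling generate().
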
def respond(message, history, system_message, max_tokens, temperature, top_p):
    if tokenizer is None or peft_model is None:
        return "Model loading failed. Please check the logs."

    # Construct the prompt from the system message and the running chat history
    # (history arrives as a list of (user_msg, assistant_msg) tuples)
    prompt = system_message
    for user_msg, assistant_msg in history:
        if user_msg:
            prompt += f"\nUser: {user_msg}"
        if assistant_msg:
            prompt += f"\nAssistant: {assistant_msg}"
    prompt += f"\nUser: {message}"

    # Tokenize the input prompt
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    try:
        outputs = peft_model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )
        # Decode the full generated sequence (prompt + completion)
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        def extract_user_content(text):
            """
            Extract and return the content between <user>...</user> tags.
            If multiple such sections exist, their contents are joined with newlines.
            """
            pattern = re.compile(r"<user>(.*?)</user>", re.IGNORECASE | re.DOTALL)
            matches = pattern.findall(text)
            return "\n".join(match.strip() for match in matches)

        # The system prompt asks the model to wrap its answer in <user> tags;
        # extraction is currently disabled and the full decoded text is returned.
        # lines = extract_user_content(generated_text)
        return generated_text
    except Exception as e:
        return f"Generation error: {e}"
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value=(
                "Take the user input in Hindi and normalize only the following entities: "
                "dates (any format), currencies, and scientific units. "
                "<Example> Exampleinput: 2012–13 में रक्षा सेवाओं के लिए 1,93,407 करोड़ रुपए का "
                "प्रावधान किया गया था, जबकि 2011–2012 में यह राशि 1,64,415 करोइ़ थी, "
                "Exampleoutput: ट्वेन्टी ट्वेल्व फिफ्टीन में रक्षा सेवाओं के लिए वन करोड़ निनेटी थ्री थाउजेंड "
                "फोर हंड्रेड सेवन करोड़ रुपए का प्रावधान किया गया था, जबकि ट्वेन्टी एलेवन ट्वेल्व में यह राशि "
                "वन करोड़ सिक्स्टी फोर थाउजेंड फोर हंड्रेड फिफ्टीन करोड़ थी </Example> "
                "Provide only the normalized output, with utmost accuracy, inside a <user> XML tag."
            ),
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=2048, value=128, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
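
# On Hugging Face Spaces, app.py is executed as the main module, so launch()
# runs on startup; the guard also allows running the file locally.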
if __name__ == "__main__":
    demo.launch()