File size: 3,971 Bytes
8e0ab8f
5534d9e
 
896a3c2
9c2dbe4
 
8e0ab8f
9c2dbe4
 
 
 
 
 
c5ecffb
9c2dbe4
 
 
896a3c2
9c2dbe4
e5d4b35
058b8bc
c5ecffb
058b8bc
9c2dbe4
c5ecffb
058b8bc
e5d4b35
 
 
 
058b8bc
e5d4b35
058b8bc
 
 
e5d4b35
ed49a22
058b8bc
9c2dbe4
6cb1f28
9c2dbe4
 
 
 
c5ecffb
9c2dbe4
058b8bc
 
 
c72585a
 
 
 
 
4cb6754
35dda4d
4cb6754
c72585a
 
35dda4d
510e9d1
058b8bc
510e9d1
058b8bc
9c2dbe4
c5ecffb
e5d4b35
8e0ab8f
 
 
35dda4d
feb6b5f
8e0ab8f
058b8bc
8e0ab8f
 
 
 
c5ecffb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import InferenceClient
import re
import torch

# Load the tokenizer and the LoRA-adapted model once, at import time, so every
# chat request reuses the same objects. The adapter weights are merged into the
# base model. If any step fails, all three handles end up as None; respond()
# checks for that and reports the failure instead of generating.
tokenizer = base_model = peft_model = None
try:
    tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
    base_model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1")
    peft_model = PeftModel.from_pretrained(
        base_model, "KGSAGAR/Sarvam-1-text-normalization-3r"
    ).merge_and_unload()
except Exception as e:
    print(f"Error loading model: {e}")
    tokenizer = base_model = peft_model = None
else:
    print("Model loaded successfully!")

# Matches one <user>...</user> section; compiled once instead of on every reply.
_USER_TAG_RE = re.compile(r"<user>(.*?)</user>", re.IGNORECASE | re.DOTALL)


def _extract_user_content(text):
    """Return the contents of every ``<user>...</user>`` section in *text*.

    Each section is stripped of surrounding whitespace and the sections are
    joined with newlines. Returns an empty string when *text* has no such
    section.
    """
    return "\n".join(match.strip() for match in _USER_TAG_RE.findall(text))


def _build_prompt(message, history, system_message):
    """Render the system message, chat history and new message as plain text.

    ``history`` may use Gradio's tuple format (``[user, assistant]`` pairs) or
    its "messages" format (dicts with ``role`` and ``content`` keys). Empty
    turns are skipped. The prompt ends with an ``Assistant:`` cue so the model
    continues speaking as the assistant.
    """
    lines = [system_message] if system_message else []
    for turn in history:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if content and role in ("user", "assistant"):
                lines.append(f"{role.capitalize()}: {content}")
            continue
        user_msg, assistant_msg = turn
        if user_msg:
            lines.append(f"User: {user_msg}")
        if assistant_msg:
            lines.append(f"Assistant: {assistant_msg}")
    lines.append(f"User: {message}")
    lines.append("Assistant:")
    return "\n".join(lines)


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Generate the model's reply to *message*; used as the ChatInterface callback.

    Args:
        message: The new user message.
        history: Previous turns, in Gradio tuple or "messages" format.
        system_message: Instruction text placed at the top of the prompt.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature passed to ``generate``.
        top_p: Nucleus-sampling threshold passed to ``generate``.

    Returns:
        The text inside ``<user>`` tags in the newly generated tokens, or the
        whole new completion if it contains no such tags. If the model failed
        to load, or tokenisation/generation raised, a short error message is
        returned instead.
    """
    if tokenizer is None or peft_model is None:
        return "Model loading failed. Please check the logs."

    prompt = _build_prompt(message, history, system_message)

    try:
        inputs = dict(tokenizer(prompt, return_tensors="pt"))

        # Truncate from the LEFT. Right-side truncation would drop the new user
        # message and the "Assistant:" cue; dropping the oldest tokens (history,
        # and any leading special token) keeps them.
        max_len = getattr(tokenizer, "model_max_length", None)
        if max_len and inputs["input_ids"].shape[-1] > max_len:
            inputs = {name: value[:, -max_len:] for name, value in inputs.items()}
        prompt_len = inputs["input_ids"].shape[-1]

        # Fall back to EOS when the tokenizer defines no pad token, so generate()
        # does not have to guess one.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        outputs = peft_model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=pad_token_id,
        )

        # generate() returns prompt + continuation; decode only the new tokens
        # so the reply does not echo the system message and history.
        completion = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    except Exception as e:
        return f"Generation error: {e}"

    # The default system message asks for the answer inside <user> tags; use the
    # raw completion when the model did not follow that format.
    return _extract_user_content(completion) or completion.strip()

# Default instructions shown in the "System message" box. The model imitates the
# few-shot example, so the example must be numerically correct:
# - 2012–13 is read as "twenty twelve thirteen".
# - 1,93,407 and 1,64,415 use Indian digit grouping and are lakh-scale amounts,
#   not crore.
DEFAULT_SYSTEM_MESSAGE = (
    "Take the user input in Hindi language and normalize specific entities, "
    "Only including: Dates (any format) Currencies Scientific units, "
    "<Example> Exampleinput :  2012–13 में रक्षा सेवाओं के लिए 1,93,407 करोड़ रुपए "
    "का प्रावधान किया गया था, जबकि 2011–2012 में यह राशि 1,64,415 करोड़ थी, "
    "Exampleoutput: ट्वेन्टी ट्वेल्व थर्टीन में रक्षा सेवाओं के लिए वन लाख नाइन्टी थ्री "
    "थाउजेंड फोर हंड्रेड सेवन करोड़ रुपए का प्रावधान किया गया था, जबकि ट्वेन्टी एलेवन "
    "ट्वेल्व में यह राशि वन लाख सिक्स्टी फोर थाउजेंड फोर हंड्रेड फिफ्टीन करोड़ थी "
    "</Example>, Only provide the normalized output with utmost accuracy inside "
    "<user> xml tag"
)

# Chat UI. The additional inputs are passed to respond() positionally after
# (message, history): system message, max new tokens, temperature, top-p.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value=DEFAULT_SYSTEM_MESSAGE, label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=128, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

if __name__ == "__main__":
    # Start the web UI only when executed as a script, not when imported.
    demo.launch()