KGSAGAR committed
Commit 9c2dbe4 · verified · 1 Parent(s): ed49a22

Update app.py

Files changed (1): app.py (+31 -28)
app.py CHANGED
@@ -2,28 +2,25 @@ import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from peft import PeftModel
 from huggingface_hub import InferenceClient
+import re
+import torch
 
 """
 For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 """
 
-# Load the tokenizer
-tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
-
-# Load the base model
-base_model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1")
-
-# Load the PEFT adapter
-peft_model = PeftModel.from_pretrained(base_model, "KGSAGAR/Sarvam-1-text-normalization-3r")
-peft_model = peft_model.merge_and_unload()
-
-
-# client = InferenceClient(peft_model)
-
-
-import re
-import torch
-from transformers import AutoTokenizer
+# Model and tokenizer loading (outside the respond function)
+try:
+    tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
+    base_model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1")
+    peft_model = PeftModel.from_pretrained(base_model, "KGSAGAR/Sarvam-1-text-normalization-3r")
+    peft_model = peft_model.merge_and_unload()
+    print("Model loaded successfully!")  # Add this line
+except Exception as e:
+    print(f"Error loading model: {e}")
+    tokenizer = None
+    base_model = None
+    peft_model = None
 
 def respond(
     message,
@@ -32,8 +29,6 @@ def respond(
     max_tokens,
     temperature,
     top_p,
-    peft_model,
-    tokenizer_name,
 ):
     """
     Generates a response based on the user message and history using the provided PEFT model.
@@ -45,12 +40,14 @@ def respond(
         max_tokens (int): The maximum number of tokens to generate.
         temperature (float): The temperature parameter for generation.
         top_p (float): The top_p parameter for nucleus sampling.
-        peft_model: The pre-trained fine-tuned model for generation.
-        tokenizer_name (str): The name or path of the tokenizer.
 
     Yields:
         str: The generated response up to the current token.
     """
+    global tokenizer, peft_model  # access global variables
+    if tokenizer is None or peft_model is None:
+        yield "Model loading failed. Please check the logs."
+        return
 
     # Construct the prompt
     prompt = system_message
@@ -63,14 +60,20 @@ def respond(
 
     # Tokenize the input prompt
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
+
     # Generate the output
-    outputs = peft_model.generate(
-        **inputs,
-        max_new_tokens=max_tokens,
-        temperature=temperature,
-        top_p=top_p,
-        do_sample=True  # Enable sampling for more diverse outputs
-    )
+    try:
+        outputs = peft_model.generate(
+            **inputs,
+            max_new_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            do_sample=True  # Enable sampling for more diverse outputs
+        )
+    except Exception as e:
+        yield f"Generation error: {e}"
+        return
+
 
     # Decode the generated tokens
     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
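The hunks above end at the generation and decoding logic; the Gradio UI wiring at the bottom of app.py is not touched by this commit and therefore not shown. For context, a minimal sketch of how `respond` is typically connected to a `gr.ChatInterface` in Spaces built from the stock chatbot template is given below. The layout, labels, and default values are assumptions for illustration, not part of the committed file.

# Hypothetical wiring, assumed from the standard Gradio chatbot template (not in this diff).
# `gr.ChatInterface` passes (message, history) to `respond`, followed by the
# additional inputs in the order listed; defaults below are placeholders.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a text normalization assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()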