sainathBelagavi committed on
Commit
e123732
·
verified ·
1 Parent(s): 33172ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -36
app.py CHANGED
@@ -2,21 +2,44 @@
2
  import gradio as gr
3
  import json
4
  import re
 
5
  from datetime import datetime
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  class TranscriptAnalyzer:
10
  def __init__(self):
11
- # Initialize the model and tokenizer
12
- self.model_name = "mistralai/Mistral-7B-Instruct-v0.2"
13
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
14
- self.model = AutoModelForCausalLM.from_pretrained(
15
- self.model_name,
16
- torch_dtype=torch.float16,
17
- device_map="auto"
18
- )
19
-
 
 
 
 
 
 
 
 
 
 
20
  def extract_dates(self, text: str):
21
  date_patterns = [
22
  r'\d{1,2}[-/]\d{1,2}[-/]\d{2,4}',
@@ -83,34 +106,40 @@ FOLLOW UP:
83
  - Pending items [/INST]</s>"""
84
 
85
  def analyze_transcript(self, transcript: str):
86
- # Generate prompt
87
- prompt = self.generate_prompt(transcript)
88
-
89
- # Tokenize input
90
- inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
91
-
92
- # Generate response
93
- with torch.no_grad():
94
- outputs = self.model.generate(
95
- **inputs,
96
- max_new_tokens=1000,
97
- temperature=0.1,
98
- do_sample=True,
99
- pad_token_id=self.tokenizer.eos_token_id
100
- )
101
-
102
- # Decode response
103
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
104
-
105
- # Extract the assistant's response (after the prompt)
106
- response = response.split("[/INST]")[-1].strip()
107
-
108
- return response
 
 
 
109
 
110
  def process_transcript(transcript: str):
111
- analyzer = TranscriptAnalyzer()
112
- analysis = analyzer.analyze_transcript(transcript)
113
- return analysis
 
 
 
114
 
115
  # Create Gradio interface
116
  iface = gr.Interface(
@@ -135,4 +164,5 @@ iface = gr.Interface(
135
  )
136
 
137
  # Launch the app
138
- iface.launch()
 
 
2
  import gradio as gr
3
  import json
4
  import re
5
+ import os
6
  from datetime import datetime
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
  import torch
9
+ from huggingface_hub import login
10
+
11
# Authenticate with the Hugging Face Hub before any model download: the
# Mistral checkpoint is gated, so a valid token from the Space secrets
# is required. Fail fast (raise) if login is impossible.
try:
    hf_token = os.environ.get('HUGGINGFACE_TOKEN')
    if not hf_token:
        raise ValueError("HUGGINGFACE_TOKEN not found in environment variables")
    login(token=hf_token)
except Exception as e:
    print(f"Error during Hugging Face login: {str(e)}")
    raise
21
 
22
  class TranscriptAnalyzer:
23
  def __init__(self):
24
+ try:
25
+ # Initialize the model and tokenizer with auth token
26
+ self.model_name = "mistralai/Mistral-7B-Instruct-v0.2"
27
+ self.tokenizer = AutoTokenizer.from_pretrained(
28
+ self.model_name,
29
+ token=hf_token,
30
+ trust_remote_code=True
31
+ )
32
+ self.model = AutoModelForCausalLM.from_pretrained(
33
+ self.model_name,
34
+ token=hf_token,
35
+ torch_dtype=torch.float16,
36
+ device_map="auto",
37
+ trust_remote_code=True
38
+ )
39
+ except Exception as e:
40
+ print(f"Error initializing model: {str(e)}")
41
+ raise
42
+
43
  def extract_dates(self, text: str):
44
  date_patterns = [
45
  r'\d{1,2}[-/]\d{1,2}[-/]\d{2,4}',
 
106
  - Pending items [/INST]</s>"""
107
 
108
  def analyze_transcript(self, transcript: str):
109
+ try:
110
+ # Generate prompt
111
+ prompt = self.generate_prompt(transcript)
112
+
113
+ # Tokenize input
114
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
115
+
116
+ # Generate response
117
+ with torch.no_grad():
118
+ outputs = self.model.generate(
119
+ **inputs,
120
+ max_new_tokens=1000,
121
+ temperature=0.1,
122
+ do_sample=True,
123
+ pad_token_id=self.tokenizer.eos_token_id
124
+ )
125
+
126
+ # Decode response
127
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
128
+
129
+ # Extract the assistant's response (after the prompt)
130
+ response = response.split("[/INST]")[-1].strip()
131
+
132
+ return response
133
+ except Exception as e:
134
+ return f"Error analyzing transcript: {str(e)}"
135
 
136
# Cache a single analyzer so the 7B model is loaded once per process,
# not rebuilt (tokenizer + weights) on every Gradio request.
_ANALYZER = None


def process_transcript(transcript: str):
    """Gradio entry point: analyze *transcript* and return the result text.

    Fix over the original: the original constructed a fresh
    ``TranscriptAnalyzer()`` — re-loading the full 7B model — on every
    call; this version lazily creates one shared instance and reuses it.
    Errors are returned as strings so the UI always gets a response.
    """
    global _ANALYZER
    try:
        if _ANALYZER is None:
            _ANALYZER = TranscriptAnalyzer()
        return _ANALYZER.analyze_transcript(transcript)
    except Exception as e:
        return f"Error processing transcript: {str(e)}"
143
 
144
  # Create Gradio interface
145
  iface = gr.Interface(
 
164
  )
165
 
166
# Start the Gradio server only when this file is executed directly,
# so importing app.py (e.g. from tests or tooling) does not launch the UI.
if __name__ == "__main__":
    iface.launch()