d3dname commited on
Commit
229fd6f
·
verified ·
1 Parent(s): 10b4193

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -58
app.py CHANGED
@@ -3,13 +3,7 @@ import streamlit.components.v1 as components
3
  import requests
4
  import os
5
  import time
6
- import torch
7
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
8
- from transformers import pipeline
9
  import streamlit as st
10
- import pytube as pt
11
- from transformers import AutoModelForSeq2SeqLM
12
-
13
  from streamlit_mic_recorder import mic_recorder
14
  MODEL_NAME = "drinktoomuchsax/whisper-small-hi"
15
  lang = "en"
@@ -27,50 +21,6 @@ BASETEN_KEY = os.environ.get("BASETEN_KEY", None)
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
 
29
 
30
- ### Whisper Start
31
-
32
- pipe = pipeline(
33
- task="automatic-speech-recognition",
34
- model=MODEL_NAME,
35
- chunk_length_s=30,
36
- device=device,
37
- )
38
- pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=lang, task="transcribe")
39
-
40
- # Transcription function
41
- def transcribe(file_path):
42
- # Load the audio file
43
- inputs = tokenizer(file_path, return_tensors="pt", padding="longest", truncation=True)
44
- inputs = {key: value.to(device) for key, value in inputs.items()}
45
-
46
- # Pass the inputs and the attention_mask to the model
47
- generated_ids = pipe.model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=1000)
48
- transcription = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
49
-
50
- return transcription
51
-
52
- # YouTube transcription function
53
- def yt_transcribe(yt_url):
54
- yt = pt.YouTube(yt_url)
55
- stream = yt.streams.filter(only_audio=True)[0]
56
- stream.download(filename="audio.mp3")
57
- transcription = transcribe("audio.mp3")
58
- return transcription
59
-
60
- ### Whisper END
61
-
62
-
63
- # Load the tokenizer and model
64
- tokenizer = AutoTokenizer.from_pretrained("merve/chatgpt-prompts-bart-long")
65
- model = AutoModelForSeq2SeqLM.from_pretrained("merve/chatgpt-prompts-bart-long", from_tf=True).to("cuda" if torch.cuda.is_available() else "cpu")
66
-
67
- # Function to generate the prompt based on the persona
68
- def generate(prompt):
69
- batch = tokenizer(prompt, return_tensors="pt").to(model.device)
70
- generated_ids = model.generate(batch["input_ids"], max_new_tokens=150)
71
- output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
72
- return output[0]
73
-
74
 
75
  #st.set_page_config(layout="wide")
76
  # Load custom CSS to integrate Bootstrap, Font Awesome, and Google Fonts
@@ -157,13 +107,11 @@ with lr:
157
  if st.button("Generate Prompt"):
158
  if persona:
159
  with st.spinner("Generating..."):
160
- result = generate(persona)
161
  st.text_area("Generated Prompt", value=result, height=200)
162
  else:
163
  st.error("Please enter a persona to generate a prompt.")
164
 
165
-
166
-
167
  with rl:
168
  # End of Box 2 and second Carousel Item
169
  st.markdown('''<h3><i class="fa fa-pencil"></i> Transcribe </h3>''', unsafe_allow_html=True)
@@ -186,11 +134,8 @@ with rl:
186
  with open("temp_recording.wav", "wb") as f:
187
  f.write(audio["bytes"])
188
  with st.spinner("Transcribing..."):
189
- transcription = transcribe("temp_recording.wav")
190
- st.text_area("Transcription", transcription, height=200)
191
- elif uploaded_file is not None:
192
- with st.spinner("Transcribing..."):
193
- transcription = transcribe(uploaded_file)
194
  st.text_area("Transcription", transcription, height=200)
195
  else:
196
  st.error("Please record audio or upload a file to transcribe.")
 
3
  import requests
4
  import os
5
  import time
 
 
 
6
  import streamlit as st
 
 
 
7
  from streamlit_mic_recorder import mic_recorder
8
  MODEL_NAME = "drinktoomuchsax/whisper-small-hi"
9
  lang = "en"
 
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  #st.set_page_config(layout="wide")
26
  # Load custom CSS to integrate Bootstrap, Font Awesome, and Google Fonts
 
107
  if st.button("Generate Prompt"):
108
  if persona:
109
  with st.spinner("Generating..."):
110
+ result = "Test"
111
  st.text_area("Generated Prompt", value=result, height=200)
112
  else:
113
  st.error("Please enter a persona to generate a prompt.")
114
 
 
 
115
  with rl:
116
  # End of Box 2 and second Carousel Item
117
  st.markdown('''<h3><i class="fa fa-pencil"></i> Transcribe </h3>''', unsafe_allow_html=True)
 
134
  with open("temp_recording.wav", "wb") as f:
135
  f.write(audio["bytes"])
136
  with st.spinner("Transcribing..."):
137
+ #transcription = transcribe("temp_recording.wav")
138
+ #need to send the data here
 
 
 
139
  st.text_area("Transcription", transcription, height=200)
140
  else:
141
  st.error("Please record audio or upload a file to transcribe.")