d3dname commited on
Commit
af03d54
·
verified ·
1 Parent(s): 550aabf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -10
app.py CHANGED
@@ -6,9 +6,10 @@ import time
6
  import torch
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
8
  import streamlit as st
9
-
10
  from transformers import AutoModelForSeq2SeqLM
11
- import torch
 
12
 
13
  from threading import Thread
14
  os.environ["COQUI_TOS_AGREED"] = "1"
@@ -20,6 +21,33 @@ HF_TOKEN = os.environ.get("HF_TOKEN", None)
20
  # Set the device to GPU or CPU
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  # Load the tokenizer and model
24
  tokenizer = AutoTokenizer.from_pretrained("merve/chatgpt-prompts-bart-long")
25
  model = AutoModelForSeq2SeqLM.from_pretrained("merve/chatgpt-prompts-bart-long", from_tf=True).to("cuda" if torch.cuda.is_available() else "cpu")
@@ -130,15 +158,24 @@ with right:
130
  st.markdown('''<h3><i class="fa fa-pencil"></i> Form 2</h3>''', unsafe_allow_html=True)
131
 
132
  # Box 3: Form 2
133
- prompt2 = st.text_input("Enter Prompt", key="prompt2")
134
- image_url2 = st.text_input("Enter Image URL", key="image_url2")
135
- if st.button("Submit Form 2", key="submit2"):
136
- payload = {"prompt": prompt2, "image_url": image_url2}
137
- response = requests.post("https://d3ndnam3-hf.space/api", json=payload)
138
- if response.status_code == 200:
139
- st.write(f"**Response:** {response.json().get('response', 'No response')}")
 
 
 
140
  else:
141
- st.write("Failed to get a response")
 
 
 
 
 
 
142
 
143
  # End of Box 3 and third Carousel Item
144
  st.markdown('''<h3><i class="fa fa-pencil"></i> Form 3</h3>''', unsafe_allow_html=True)
 
6
  import torch
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
8
  import streamlit as st
9
+ import pytube as pt
10
  from transformers import AutoModelForSeq2SeqLM
11
+ MODEL_NAME = "drinktoomuchsax/whisper-small-hi"
12
+ lang = "en"
13
 
14
  from threading import Thread
15
  os.environ["COQUI_TOS_AGREED"] = "1"
 
21
  # Set the device to GPU or CPU
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
 
24
+
25
+ ### Whisper Start
26
+
27
+ pipe = pipeline(
28
+ task="automatic-speech-recognition",
29
+ model=MODEL_NAME,
30
+ chunk_length_s=30,
31
+ device=device,
32
+ )
33
+ pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=lang, task="transcribe")
34
+
35
+ # Transcription function
36
+ def transcribe(file):
37
+ text = pipe(file)["text"]
38
+ return text
39
+
40
+ # YouTube transcription function
41
+ def yt_transcribe(yt_url):
42
+ yt = pt.YouTube(yt_url)
43
+ stream = yt.streams.filter(only_audio=True)[0]
44
+ stream.download(filename="audio.mp3")
45
+ text = pipe("audio.mp3")["text"]
46
+ return text
47
+
48
+ ### Whisper END
49
+
50
+
51
  # Load the tokenizer and model
52
  tokenizer = AutoTokenizer.from_pretrained("merve/chatgpt-prompts-bart-long")
53
  model = AutoModelForSeq2SeqLM.from_pretrained("merve/chatgpt-prompts-bart-long", from_tf=True).to("cuda" if torch.cuda.is_available() else "cpu")
 
158
  st.markdown('''<h3><i class="fa fa-pencil"></i> Form 2</h3>''', unsafe_allow_html=True)
159
 
160
  # Box 3: Form 2
161
+ uploaded_file = st.file_uploader("Upload an audio file", type=["mp3", "wav", "flac", "aac"])
162
+ microphone_input = st.audio("Record audio using microphone", format="audio/wav")
163
+ if st.button("Transcribe"):
164
+ if microphone_input and uploaded_file:
165
+ st.warning("WARNING: You've uploaded an audio file and used the microphone. The recorded file from the microphone will be used, and the uploaded audio will be discarded.")
166
+ file_to_transcribe = microphone_input
167
+ elif microphone_input:
168
+ file_to_transcribe = microphone_input
169
+ elif uploaded_file:
170
+ file_to_transcribe = uploaded_file
171
  else:
172
+ st.error("ERROR: You have to either use the microphone or upload an audio file")
173
+ file_to_transcribe = None
174
+
175
+ if file_to_transcribe:
176
+ with st.spinner("Transcribing..."):
177
+ transcription = transcribe(file_to_transcribe)
178
+ st.write(transcription)
179
 
180
  # End of Box 3 and third Carousel Item
181
  st.markdown('''<h3><i class="fa fa-pencil"></i> Form 3</h3>''', unsafe_allow_html=True)