Josephanthraper committed on
Commit
82d0705
·
verified ·
1 Parent(s): e3f3705

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -30
app.py CHANGED
@@ -1,48 +1,112 @@
1
- # app.py
2
- import gradio as gr
3
- from whisper_jax import FlaxWhisperPipline
4
  import jax.numpy as jnp
 
 
 
5
  from pydub import AudioSegment
6
- import os
 
 
 
7
 
8
- # Load Whisper JAX model once (on startup)
9
  asr_pipeline = FlaxWhisperPipline(
10
  "parthiv11/indic_whisper_nodcil",
11
  dtype=jnp.bfloat16
12
  )
13
 
14
- # Convert audio to wav (most stable for Whisper)
15
- def convert_to_wav(audio_file):
16
- wav_path = audio_file.rsplit(".", 1)[0] + ".wav"
17
- sound = AudioSegment.from_file(audio_file)
18
  sound.export(wav_path, format="wav")
19
  return wav_path
20
 
21
- # Function connected to Gradio
22
- def transcribe(audio_file):
23
- if audio_file is None:
24
- return "Please upload an audio file."
25
 
 
 
26
  try:
27
- wav_file = convert_to_wav(audio_file)
28
- result = asr_pipeline(wav_file)
 
 
 
 
 
29
 
30
- # Clean up temp wav file if created
31
- if wav_file != audio_file and os.path.exists(wav_file):
32
- os.remove(wav_file)
 
 
 
 
 
33
 
34
- return result["text"] if isinstance(result, dict) else result
35
  except Exception as e:
36
- return f"Error processing audio: {str(e)}"
37
-
38
- # Build UI
39
- demo = gr.Interface(
40
- fn=transcribe,
41
- inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
42
- outputs=gr.Textbox(label="Transcription", lines=10),
43
- title="Whisper (JAX)",
44
- description="Upload or record Hindi speech and get transcription"
45
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  if __name__ == "__main__":
48
- demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
 
1
# fastapi_app.py
# FastAPI service exposing Whisper JAX speech-to-text (Hindi model) over HTTP.
import os

import jax.numpy as jnp
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
from whisper_jax import FlaxWhisperPipline
from pydub import AudioSegment
import uvicorn
import tempfile

app = FastAPI(title="Whisper JAX API", description="Transcribe Hindi speech using Whisper JAX", version="1.0")

# Load Whisper JAX model once on startup
# Module-level load so every request reuses the same compiled pipeline.
# bfloat16 is a reduced-precision dtype chosen to cut memory use for inference.
asr_pipeline = FlaxWhisperPipline(
    "parthiv11/indic_whisper_nodcil",
    dtype=jnp.bfloat16
)
18
 
19
# Convert audio to wav
def convert_to_wav(input_path: str) -> str:
    """Convert any audio file pydub can decode to a WAV file.

    Args:
        input_path: Path to the source audio file.

    Returns:
        Path to the exported WAV file (the input path with its final
        extension replaced by ``.wav``).

    Raises:
        Whatever pydub/ffmpeg raises if the input cannot be decoded.
    """
    # os.path.splitext only strips the final extension of the basename,
    # so a dot inside a directory name (e.g. /tmp/run.v2/clip) cannot
    # truncate the path — unlike the previous rsplit(".", 1) approach.
    root, _ext = os.path.splitext(input_path)
    wav_path = root + ".wav"
    sound = AudioSegment.from_file(input_path)
    sound.export(wav_path, format="wav")
    return wav_path
25
 
26
@app.get("/")
async def root():
    """Liveness endpoint: confirms the API process is serving requests."""
    status_payload = {"message": "Whisper JAX API is running!"}
    return status_payload
 
29
 
30
@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with the Whisper JAX pipeline.

    The upload is spooled to a temp file (preserving its extension so
    pydub/ffmpeg can sniff the format), converted to WAV, and passed to
    ``asr_pipeline``.

    Returns:
        JSONResponse with ``{"transcription": <text>}``.

    Raises:
        HTTPException(500) with the underlying error message on failure.
    """
    temp_path = None
    wav_path = None
    try:
        suffix = os.path.splitext(file.filename)[-1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_file.write(await file.read())
            temp_path = temp_file.name

        wav_path = convert_to_wav(temp_path)
        result = asr_pipeline(wav_path)

        # Pipeline may return a dict ({"text": ...}) or a bare string.
        transcription = result["text"] if isinstance(result, dict) else result
        return JSONResponse(content={"transcription": transcription})

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")
    finally:
        # Cleanup in finally so temp files are removed even when
        # conversion or transcription raises — the previous version only
        # cleaned up on the success path, leaking a file per failed request.
        for path in (temp_path, wav_path):
            if path and os.path.exists(path):
                os.remove(path)
52
+
53
@app.get("/ui")
async def serve_ui():
    """Serve a minimal self-contained HTML page that uploads audio to /transcribe.

    The page posts the selected file as multipart form data (field name
    ``file``, matching the /transcribe parameter) and renders the returned
    ``transcription`` field into the output box.
    """
    # NOTE(review): original indentation of this HTML was lost in the diff
    # rendering; content below is whitespace-insensitive HTML/JS.
    html_content = """
    <!DOCTYPE html>
    <html>
    <head>
        <title>Whisper JAX Transcription</title>
        <style>
            body { font-family: Arial, sans-serif; max-width: 600px; margin: 30px auto; padding: 20px; background: #f4f6f8; }
            h2 { text-align: center; }
            .card { background: white; padding: 20px; border-radius: 12px; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
            #output { margin-top: 20px; padding: 15px; background: #fafafa; border: 1px solid #ddd; border-radius: 8px; min-height: 100px; }
            button { padding: 10px 20px; border: none; border-radius: 8px; background: #06b6d4; color: white; cursor: pointer; }
            button:hover { background: #0891b2; }
        </style>
    </head>
    <body>
        <div class="card">
            <h2>Whisper (JAX) Speech-to-Text</h2>
            <form id="uploadForm">
                <input type="file" id="audioFile" name="file" accept="audio/*" required />
                <button type="submit">Transcribe</button>
            </form>
            <div id="output">Transcription will appear here...</div>
        </div>

        <script>
            document.getElementById("uploadForm").addEventListener("submit", async function(e) {
                e.preventDefault();
                const fileInput = document.getElementById("audioFile");
                if (!fileInput.files.length) return;

                const formData = new FormData();
                formData.append("file", fileInput.files[0]);

                document.getElementById("output").innerText = "Processing...";

                try {
                    const response = await fetch("/transcribe", {
                        method: "POST",
                        body: formData
                    });
                    const data = await response.json();
                    if (data.transcription) {
                        document.getElementById("output").innerText = data.transcription;
                    } else {
                        document.getElementById("output").innerText = "Error: " + JSON.stringify(data);
                    }
                } catch (err) {
                    document.getElementById("output").innerText = "Failed: " + err.message;
                }
            });
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)
110
+
111
if __name__ == "__main__":
    # Pass the app object rather than the import string "fastapi_app:app":
    # this commit saves the file as app.py (see the commit header), so the
    # hard-coded module name would raise ModuleNotFoundError at startup.
    # reload=True requires an import string and is a dev-only feature, so
    # it is dropped for this deployment entry point.
    uvicorn.run(app, host="0.0.0.0", port=7860)