Josephanthraper committed on
Commit
b96a003
·
verified ·
1 Parent(s): 57ea62f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -104
app.py CHANGED
@@ -1,112 +1,30 @@
1
- # fastapi_app.py
2
- import os
3
- import jax.numpy as jnp
4
- from fastapi import FastAPI, UploadFile, File, HTTPException
5
- from fastapi.responses import JSONResponse, HTMLResponse
6
  from whisper_jax import FlaxWhisperPipline
7
- from pydub import AudioSegment
8
- import uvicorn
9
- import tempfile
10
-
11
# FastAPI application exposing the Whisper JAX transcription endpoints.
app = FastAPI(title="Whisper JAX API", description="Transcribe Hindi speech using Whisper JAX", version="1.0")

# Load Whisper JAX model once on startup
# NOTE(review): the model is instantiated at import time, so the first
# startup may be slow while weights are fetched — confirm this is acceptable
# for the deployment target.
asr_pipeline = FlaxWhisperPipline(
    "parthiv11/indic_whisper_nodcil",
    dtype=jnp.bfloat16
)
18
 
19
# Convert audio to wav
def convert_to_wav(input_path: str) -> str:
    """Convert the audio file at *input_path* to WAV format.

    The WAV file is written next to the input, with the extension replaced
    by ``.wav`` (or ``.wav`` appended when the name has no extension).

    Returns the path of the written WAV file.
    """
    from pathlib import Path

    # Fix: the previous ``input_path.rsplit(".", 1)[0] + ".wav"`` collapsed
    # extension-less dot-file names (e.g. ".bashrc") to the bare path ".wav";
    # Path.with_suffix handles that case and is otherwise equivalent.
    wav_path = str(Path(input_path).with_suffix(".wav"))
    sound = AudioSegment.from_file(input_path)
    sound.export(wav_path, format="wav")
    return wav_path
25
-
26
@app.get("/")
async def root():
    """Health-check endpoint confirming the API is running."""
    return {"message": "Whisper JAX API is running!"}
29
-
30
@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with the Whisper JAX pipeline.

    The upload is spooled to a temporary file, converted to WAV, and fed to
    ``asr_pipeline``. Returns ``{"transcription": <text>}`` on success and
    raises HTTP 500 on any processing failure.
    """
    temp_path = None
    wav_path = None
    try:
        suffix = os.path.splitext(file.filename)[-1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_file.write(await file.read())
            temp_path = temp_file.name

        wav_path = convert_to_wav(temp_path)
        result = asr_pipeline(wav_path)

        # The pipeline may return either a dict with a "text" key or a plain
        # string depending on the whisper_jax version in use.
        transcription = result["text"] if isinstance(result, dict) else result
        return JSONResponse(content={"transcription": transcription})
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")
    finally:
        # Fix: cleanup previously ran only on the success path, leaking both
        # temp files whenever conversion or transcription raised.
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
        if wav_path and wav_path != temp_path and os.path.exists(wav_path):
            os.remove(wav_path)
52
-
53
@app.get("/ui")
async def serve_ui():
    """Serve a minimal single-page HTML UI for the /transcribe endpoint.

    The page holds an upload form plus inline JavaScript that POSTs the
    chosen file as multipart form data to ``/transcribe`` and renders the
    returned ``transcription`` field (or the raw error payload) into the
    output div.
    """
    html_content = """
    <!DOCTYPE html>
    <html>
    <head>
        <title>Whisper JAX Transcription</title>
        <style>
            body { font-family: Arial, sans-serif; max-width: 600px; margin: 30px auto; padding: 20px; background: #f4f6f8; }
            h2 { text-align: center; }
            .card { background: white; padding: 20px; border-radius: 12px; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
            #output { margin-top: 20px; padding: 15px; background: #fafafa; border: 1px solid #ddd; border-radius: 8px; min-height: 100px; }
            button { padding: 10px 20px; border: none; border-radius: 8px; background: #06b6d4; color: white; cursor: pointer; }
            button:hover { background: #0891b2; }
        </style>
    </head>
    <body>
        <div class="card">
            <h2>Whisper (JAX) Speech-to-Text</h2>
            <form id="uploadForm">
                <input type="file" id="audioFile" name="file" accept="audio/*" required />
                <button type="submit">Transcribe</button>
            </form>
            <div id="output">Transcription will appear here...</div>
        </div>

        <script>
            document.getElementById("uploadForm").addEventListener("submit", async function(e) {
                e.preventDefault();
                const fileInput = document.getElementById("audioFile");
                if (!fileInput.files.length) return;

                const formData = new FormData();
                formData.append("file", fileInput.files[0]);

                document.getElementById("output").innerText = "Processing...";

                try {
                    const response = await fetch("/transcribe", {
                        method: "POST",
                        body: formData
                    });
                    const data = await response.json();
                    if (data.transcription) {
                        document.getElementById("output").innerText = data.transcription;
                    } else {
                        document.getElementById("output").innerText = "Error: " + JSON.stringify(data);
                    }
                } catch (err) {
                    document.getElementById("output").innerText = "Failed: " + err.message;
                }
            });
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)
110
 
 
 
 
 
 
 
 
 
 
 
 
 
111
if __name__ == "__main__":
    # Dev entry point: bind on all interfaces, port 7860 (the Hugging Face
    # Spaces default). reload=True is for local development only.
    uvicorn.run("fastapi_app:app", host="0.0.0.0", port=7860, reload=True)
 
1
+ # app.py
2
+ import gradio as gr
 
 
 
3
  from whisper_jax import FlaxWhisperPipline
4
+ import jax.numpy as jnp
 
 
 
 
5
 
6
# Load Whisper JAX model once (on startup)
# NOTE(review): constructed at import time, so weights are fetched before the
# UI comes up — the first launch may be slow; confirm acceptable.
pipeline = FlaxWhisperPipline(
    "parthiv11/indic_whisper_nodcil",
    dtype=jnp.bfloat16
)
11
 
12
# Function connected to Gradio
def transcribe(audio_file):
    """Run the Whisper JAX pipeline over *audio_file* and return its text.

    ``audio_file`` is the file path Gradio hands over (or ``None`` when the
    user submitted nothing, in which case a prompt string is returned).
    """
    if audio_file is None:
        return "Please upload an audio file."

    output = pipeline(audio_file)
    # The pipeline may yield a dict carrying the text or the text itself.
    if isinstance(output, dict):
        return output["text"]
    return output
19
+
20
# Build UI (this part comes from Playground export)
# Single audio input (upload or microphone, passed to transcribe as a file
# path) mapped to a plain-text output box.
demo = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    # NOTE(review): title carries a trailing space — intentional? confirm.
    title="Hindi Whisper ",
    description="Upload or record Hindi speech and get transcription"
)
28
+
29
if __name__ == "__main__":
    # Bind on all interfaces, port 7860 (the Hugging Face Spaces default);
    # show_error surfaces Python tracebacks in the browser UI.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)