Akwbw commited on
Commit
2c17bee
·
verified ·
1 Parent(s): 024a9f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -24
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import torch
3
  from transformers import VitsModel, AutoTokenizer
4
  from fastapi import FastAPI, HTTPException, Header
5
- from fastapi.responses import FileResponse
6
  from fastapi.middleware.cors import CORSMiddleware
7
  import scipy.io.wavfile
8
  import uuid
@@ -10,7 +10,7 @@ import numpy as np
10
 
11
  app = FastAPI()
12
 
13
- # --- CORS Permissions ---
14
  app.add_middleware(
15
  CORSMiddleware,
16
  allow_origins=["*"],
@@ -23,61 +23,79 @@ OUTPUT_DIR = "/tmp"
23
  os.makedirs(OUTPUT_DIR, exist_ok=True)
24
  API_KEY = os.getenv("API_KEY", "MySecretPassword123")
25
 
26
- print("⏳ Loading AI Models... (Thora time lagega)")
27
 
28
- # --- LOADING REAL AI MODELS (Facebook MMS) ---
29
- # Ye models pehli baar run honay par download honge (10-20 seconds)
 
 
 
 
 
30
  try:
31
- # Urdu Model
32
- model_ur = VitsModel.from_pretrained("facebook/mms-tts-urd-script-arabic")
33
- tokenizer_ur = AutoTokenizer.from_pretrained("facebook/mms-tts-urd-script-arabic")
 
34
 
35
- # Hindi Model
 
36
  model_hi = VitsModel.from_pretrained("facebook/mms-tts-hin")
37
  tokenizer_hi = AutoTokenizer.from_pretrained("facebook/mms-tts-hin")
38
 
39
- print("✅ AI Models Loaded Successfully!")
 
40
  except Exception as e:
41
- print(f"❌ Model Loading Error: {e}")
42
 
43
  @app.get("/")
44
  def home():
45
- return {"status": "Online", "message": "Real AI VITS Model Running"}
 
 
46
 
47
  @app.post("/generate")
48
  async def generate_tts(
49
  text: str,
50
- voice_id: str = "urdu", # urdu or hindi
51
  x_api_key: str = Header(None)
52
  ):
53
- # 1. Security Check
54
  if x_api_key != API_KEY:
55
  raise HTTPException(status_code=401, detail="Invalid API Key")
56
 
 
 
 
 
57
  filename = f"{uuid.uuid4()}.wav"
58
  filepath = os.path.join(OUTPUT_DIR, filename)
59
 
60
  try:
61
- # 2. Select Language Model
 
 
 
62
  if "hindi" in voice_id.lower():
63
  inputs = tokenizer_hi(text, return_tensors="pt")
64
- with torch.no_grad():
65
- output = model_hi(**inputs).waveform
66
  else:
67
- # Default to Urdu
68
  inputs = tokenizer_ur(text, return_tensors="pt")
69
- with torch.no_grad():
70
- output = model_ur(**inputs).waveform
 
 
 
71
 
72
- # 3. Save Audio File (WAV format)
73
- # Convert PyTorch tensor to audio file
74
  audio_data = output.cpu().numpy().squeeze()
75
- scipy.io.wavfile.write(filepath, rate=model_ur.config.sampling_rate, data=audio_data)
76
 
77
  return FileResponse(filepath, media_type="audio/wav", filename="ai_audio.wav")
78
 
79
  except Exception as e:
80
- return {"error": str(e)}
81
 
82
  if __name__ == "__main__":
83
  import uvicorn
 
2
  import torch
3
  from transformers import VitsModel, AutoTokenizer
4
  from fastapi import FastAPI, HTTPException, Header
5
+ from fastapi.responses import FileResponse, JSONResponse
6
  from fastapi.middleware.cors import CORSMiddleware
7
  import scipy.io.wavfile
8
  import uuid
 
10
 
11
  app = FastAPI()
12
 
13
+ # --- CORS Permissions (Zaroori) ---
14
  app.add_middleware(
15
  CORSMiddleware,
16
  allow_origins=["*"],
 
23
  os.makedirs(OUTPUT_DIR, exist_ok=True)
24
  API_KEY = os.getenv("API_KEY", "MySecretPassword123")
25
 
26
+ print("⏳ Loading AI Models... (1-2 minute lagenge)")
27
 
28
+ # --- GLOBAL VARIABLES FOR MODELS ---
29
+ model_ur = None
30
+ tokenizer_ur = None
31
+ model_hi = None
32
+ tokenizer_hi = None
33
+
34
+ # --- LOADING REAL AI MODELS (Fixed Names) ---
35
  try:
36
+ # 1. URDU MODEL (Correct Name: facebook/mms-tts-urd)
37
+ print("Downloading Urdu Model...")
38
+ model_ur = VitsModel.from_pretrained("facebook/mms-tts-urd")
39
+ tokenizer_ur = AutoTokenizer.from_pretrained("facebook/mms-tts-urd")
40
 
41
+ # 2. HINDI MODEL (Correct Name: facebook/mms-tts-hin)
42
+ print("Downloading Hindi Model...")
43
  model_hi = VitsModel.from_pretrained("facebook/mms-tts-hin")
44
  tokenizer_hi = AutoTokenizer.from_pretrained("facebook/mms-tts-hin")
45
 
46
+ print("✅ All AI Models Loaded Successfully!")
47
+
48
  except Exception as e:
49
+ print(f"❌ CRITICAL ERROR LOADING MODELS: {e}")
50
 
51
  @app.get("/")
52
  def home():
53
+ if model_ur is None:
54
+ return {"status": "Error", "message": "Models failed to load. Check Logs."}
55
+ return {"status": "Online", "message": "Real AI VITS Model Ready (Fixed)"}
56
 
57
  @app.post("/generate")
58
  async def generate_tts(
59
  text: str,
60
+ voice_id: str = "urdu",
61
  x_api_key: str = Header(None)
62
  ):
63
+ # 1. Security
64
  if x_api_key != API_KEY:
65
  raise HTTPException(status_code=401, detail="Invalid API Key")
66
 
67
+ # 2. Check if models exist
68
+ if model_ur is None or model_hi is None:
69
+ return JSONResponse(status_code=500, content={"error": "Models not loaded yet. Check Server Logs."})
70
+
71
  filename = f"{uuid.uuid4()}.wav"
72
  filepath = os.path.join(OUTPUT_DIR, filename)
73
 
74
  try:
75
+ # 3. Generate Logic
76
+ inputs = None
77
+ model = None
78
+
79
  if "hindi" in voice_id.lower():
80
  inputs = tokenizer_hi(text, return_tensors="pt")
81
+ model = model_hi
 
82
  else:
83
+ # Urdu Default
84
  inputs = tokenizer_ur(text, return_tensors="pt")
85
+ model = model_ur
86
+
87
+ # 4. Create Waveform (No Internet Needed)
88
+ with torch.no_grad():
89
+ output = model(**inputs).waveform
90
 
91
+ # 5. Save File
 
92
  audio_data = output.cpu().numpy().squeeze()
93
+ scipy.io.wavfile.write(filepath, rate=model.config.sampling_rate, data=audio_data)
94
 
95
  return FileResponse(filepath, media_type="audio/wav", filename="ai_audio.wav")
96
 
97
  except Exception as e:
98
+ return JSONResponse(status_code=500, content={"error": str(e)})
99
 
100
  if __name__ == "__main__":
101
  import uvicorn