farid678 commited on
Commit
b814e71
·
verified ·
1 Parent(s): 9a94b51

create app.py file

Browse files
Files changed (1) hide show
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ from unsloth import FastLanguageModel
5
+ from snac import SNAC
6
+ import torchaudio
7
+ import io
8
+
9
+ # -----------------------------
10
+ # CONFIG
11
+ # -----------------------------
12
+ BASE_MODEL = "unsloth/Orpheus-3B"
13
+ ADAPTER_PATH = "model" # put your adapter files here
14
+ SNAC_MODEL = "snacai/snac_24khz"
15
+
16
+ # -----------------------------
17
+ # LOAD TOKENIZER
18
+ # -----------------------------
19
+ tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH, use_fast=True)
20
+
21
+ # -----------------------------
22
+ # LOAD BASE MODEL + LORA
23
+ # -----------------------------
24
+ model = FastLanguageModel.from_pretrained(
25
+ model_name = BASE_MODEL,
26
+ max_seq_length = 4096,
27
+ load_in_4bit = False,
28
+ )
29
+
30
+ model = FastLanguageModel.load_lora(
31
+ model,
32
+ ADAPTER_PATH,
33
+ )
34
+ model.eval()
35
+
36
+ # -----------------------------
37
+ # LOAD SNAC CODEC
38
+ # -----------------------------
39
+ device = "cuda" if torch.cuda.is_available() else "cpu"
40
+ codec = SNAC.from_pretrained(SNAC_MODEL).to(device)
41
+
42
+ # -----------------------------
43
+ # INFERENCE FUNCTION
44
+ # -----------------------------
45
+ def tts_generate(text):
46
+ if not text.strip():
47
+ return None
48
+
49
+ inputs = tokenizer(text, return_tensors="pt").to(device)
50
+
51
+ with torch.no_grad():
52
+ outputs = model.generate(
53
+ **inputs,
54
+ max_new_tokens=1024,
55
+ do_sample=True,
56
+ temperature=0.8,
57
+ top_p=0.9,
58
+ eos_token_id=tokenizer.eos_token_id,
59
+ )
60
+
61
+ # Extract audio codes
62
+ generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
63
+ codes = generated_ids.unsqueeze(0).to(device)
64
+
65
+ # Decode using SNAC
66
+ audio = codec.decode(codes).cpu().squeeze().numpy()
67
+
68
+ # Convert to WAV data for Gradio
69
+ buffer = io.BytesIO()
70
+ torchaudio.save(buffer, torch.tensor(audio).unsqueeze(0), 24000, format="wav")
71
+ buffer.seek(0)
72
+
73
+ return (24000, audio)
74
+
75
+ # -----------------------------
76
+ # GRADIO INTERFACE
77
+ # -----------------------------
78
+ demo = gr.Interface(
79
+ fn=tts_generate,
80
+ inputs=gr.Textbox(label="متن را وارد کنید"),
81
+ outputs=gr.Audio(label="صدای تولید‌شده"),
82
+ title="Unsloth TTS (Orpheus 3B + LoRA)",
83
+ description="متن را وارد کنید تا مدل صدا تولید کند.",
84
+ )
85
+
86
+ if __name__ == "__main__":
87
+ demo.launch()