suhyun12 committed on
Commit
a246a1f
·
verified ·
1 Parent(s): 940e061

안드로이드 앱 연동, ngrok 주소 사용

Browse files
Files changed (1) hide show
  1. app_v3/app_v3.py +160 -0
app_v3/app_v3.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import random
import tempfile

import torch
import torchaudio
from fastapi import FastAPI, File, Request, Response, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from prometheus_client import Counter
from prometheus_fastapi_instrumentator import Instrumentator
from pydub import AudioSegment
from transformers import (
    AutoConfig,
    BartForConditionalGeneration,
    PreTrainedTokenizerFast,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)
18
+
19
+ app = FastAPI()
20
+ Instrumentator().instrument(app).expose(app)
21
+
22
+ # ▶️ Custom Network Metrics
23
+ request_bytes = Counter("request_bytes_total", "Total request bytes", ["path"])
24
+ response_bytes = Counter("response_bytes_total", "Total response bytes", ["path"])
25
+
26
+ @app.middleware("http")
27
+ async def track_bytes(request: Request, call_next):
28
+ path = request.url.path
29
+ req_len = int(request.headers.get("content-length", 0))
30
+ request_bytes.labels(path=path).inc(req_len)
31
+
32
+ response = await call_next(request)
33
+
34
+ body = b""
35
+ async for chunk in response.body_iterator:
36
+ body += chunk
37
+ resp_len = len(body)
38
+
39
+ response_bytes.labels(path=path).inc(resp_len)
40
+ return Response(content=body, status_code=response.status_code, headers=dict(response.headers), media_type=response.media_type)
41
+
42
+ # ▶️ Device
43
+ device = "cuda" if torch.cuda.is_available() else "cpu"
44
+
45
+ # ▶️ CORS
46
+ app.add_middleware(
47
+ CORSMiddleware,
48
+ allow_origins=["*"],
49
+ allow_credentials=True,
50
+ allow_methods=["*"],
51
+ allow_headers=["*"],
52
+ )
53
+
54
+ # ▶️ Whisper
55
+ #local_model_path = r"D:\\whisper\\whisper_partial_ep16\\step30000-36000"
56
+ #whisper_model = WhisperForConditionalGeneration.from_pretrained(local_model_path).to(device)
57
+ whisper_model = WhisperForConditionalGeneration.from_pretrained(
58
+ "jeju-potato/jeju_potato_models", subfolder="whisper_v1_ep16"
59
+ ).to(device)
60
+ whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-base")
61
+ whisper_model.config.forced_decoder_ids = None
62
+ whisper_model.generation_config.forced_decoder_ids = None
63
+
64
+ # ▶️ KoBART 1
65
+ first_tokenizer = PreTrainedTokenizerFast.from_pretrained("jeju-potato/jeju_potato_models", subfolder="first_kobart_v3")
66
+ first_model = BartForConditionalGeneration.from_pretrained("jeju-potato/jeju_potato_models", subfolder="first_kobart_v3").to(device)
67
+ first_model.config.forced_decoder_ids = None
68
+ first_model.generation_config.forced_decoder_ids = None
69
+
70
+ # ▶️ KoBART 2
71
+ #second_config = AutoConfig.from_pretrained("jeju-potato/jeju_potato_models", subfolder="second_kobart_v2")
72
+ #second_config.num_labels = 2
73
+ #second_config.id2label = {0: "NEGATIVE", 1: "POSITIVE"}
74
+ #second_config.label2id = {"NEGATIVE": 0, "POSITIVE": 1}
75
+ second_tokenizer = PreTrainedTokenizerFast.from_pretrained("jeju-potato/jeju_potato_models", subfolder="second_kobart_v2")
76
+ second_model = BartForConditionalGeneration.from_pretrained("jeju-potato/jeju_potato_models", subfolder="second_kobart_v2").to(device)
77
+ second_model.config.forced_decoder_ids = None
78
+ second_model.generation_config.forced_decoder_ids = None
79
+
80
+ # ▶️ Whisper inference
81
+ def inference_whisper(audio_path):
82
+ waveform, sr = torchaudio.load(audio_path)
83
+ waveform = waveform.mean(dim=0) if waveform.shape[0] > 1 else waveform.squeeze(0)
84
+ inputs = whisper_processor(
85
+ waveform.numpy(),
86
+ sampling_rate=sr,
87
+ return_tensors="pt",
88
+ task="transcribe"
89
+ ).to(device)
90
+ whisper_model.generation_config.forced_decoder_ids = None
91
+ whisper_model.generation_config.suppress_tokens = []
92
+
93
+ with torch.no_grad():
94
+ pred_ids = whisper_model.generate(inputs["input_features"], max_length=128)
95
+ decoded = whisper_processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
96
+ return decoded
97
+
98
+ # ▶️ KoBART inference
99
+ def inference_kobart(text, model, tokenizer):
100
+ encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
101
+ input_ids = encoded["input_ids"].to(device)
102
+ attention_mask = encoded["attention_mask"].to(device)
103
+
104
+ with torch.no_grad():
105
+ pred_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=128)
106
+ result = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)[0]
107
+ return result
108
+
109
+ # ▶️ 전체 파이프라인
110
+ def run_pipeline(audio_path):
111
+ whisper_text = inference_whisper(audio_path)
112
+ first_kobart_text = inference_kobart(whisper_text, first_model, first_tokenizer)
113
+ second_kobart_text = inference_kobart(first_kobart_text, second_model, second_tokenizer)
114
+ return whisper_text, first_kobart_text, second_kobart_text
115
+
116
+ # ▶️ test_random
117
+ def test_random_file():
118
+ audio_dir = r"D:\\whisper\\audio1_all"
119
+ candidates = [f for f in os.listdir(audio_dir) if f.endswith(".wav")]
120
+ if not candidates:
121
+ return {"error": "No .wav files found in data/source_data"}
122
+ chosen_file = random.choice(candidates)
123
+ audio_path = os.path.join(audio_dir, chosen_file)
124
+ try:
125
+ whisper_text, first_kobart_text, second_kobart_text = run_pipeline(audio_path)
126
+ except Exception as e:
127
+ return {"error": str(e)}
128
+ return {
129
+ "filename": chosen_file,
130
+ "whisper_result": whisper_text,
131
+ "first_kobart_result": first_kobart_text,
132
+ "second_kobart_result": second_kobart_text
133
+ }
134
+
135
+ # ▶️ POST /inference
136
+ @app.post("/inference")
137
+ async def inference(audio: UploadFile = File(...)):
138
+ os.makedirs("temp", exist_ok=True)
139
+ input_path = f"temp/{audio.filename}"
140
+ with open(input_path, "wb") as f:
141
+ f.write(await audio.read())
142
+
143
+ output_path = "temp/converted.wav"
144
+ try:
145
+ sound = AudioSegment.from_file(input_path)
146
+ sound = sound.set_frame_rate(16000).set_channels(1)
147
+ sound.export(output_path, format="wav")
148
+ except Exception as e:
149
+ return JSONResponse(status_code=400, content={"error": f"와이파 변환 실패: {e}"})
150
+
151
+ try:
152
+ whisper_text, first_kobart_text, second_kobart_text = run_pipeline(output_path)
153
+ except Exception as e:
154
+ return JSONResponse(status_code=500, content={"error": f"모델 추론 실패: {e}"})
155
+
156
+ return {
157
+ "whisper_result": whisper_text,
158
+ "first_kobart_result": first_kobart_text,
159
+ "second_kobart_result": second_kobart_text
160
+ }