bigeco commited on
Commit
80de494
ยท
verified ยท
1 Parent(s): c36e528

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -143
app.py CHANGED
@@ -1,161 +1,82 @@
1
- """TODO: ๋ชจ๋ธ ํ”„๋ ˆ์ž„์›Œํฌ ๋Œ์•„๊ฐ€๋„๋ก ๊ตฌํ˜„
2
- """
3
-
4
- import logging
5
- from fastapi import FastAPI
6
  from fastapi.responses import HTMLResponse
7
- from fastapi.middleware.cors import CORSMiddleware
8
- import uvicorn
 
 
 
 
 
9
 
10
- # ๋กœ๊น… ์„ค์ •
11
- logging.basicConfig(level=logging.INFO)
12
- logger = logging.getLogger(__name__)
13
 
14
- # FastAPI ์•ฑ ์ƒ์„ฑ
15
  app = FastAPI(
16
- title="Korean Speech Recognition API",
17
- description="Real-time Korean speech recognition using Whisper",
18
  version="1.0.0"
19
  )
20
 
21
- # CORS ์„ค์ • ์ถ”๊ฐ€
22
- app.add_middleware(
23
- CORSMiddleware,
24
- allow_origins=["*"],
25
- allow_credentials=True,
26
- allow_methods=["*"],
27
- allow_headers=["*"],
28
- )
29
 
 
 
 
 
 
 
 
30
  @app.get("/", response_class=HTMLResponse)
31
- async def root():
32
- """๊ธฐ๋ณธ ํŽ˜์ด์ง€"""
33
  return """
34
- <!DOCTYPE html>
35
  <html>
36
- <head>
37
- <title>Korean Speech Recognition</title>
38
- <style>
39
- body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
40
- .container { background: #f5f5f5; padding: 20px; border-radius: 8px; }
41
- .endpoint { background: white; padding: 10px; margin: 10px 0; border-radius: 4px; }
42
- button { background: #007bff; color: white; border: none; padding: 10px 20px; border-radius: 4px; cursor: pointer; }
43
- button:hover { background: #0056b3; }
44
- .status { margin: 10px 0; padding: 10px; border-radius: 4px; }
45
- .success { background: #d4edda; color: #155724; }
46
- .error { background: #f8d7da; color: #721c24; }
47
- .warning { background: #fff3cd; color: #856404; }
48
- </style>
49
- </head>
50
- <body>
51
- <div class="container">
52
- <h1>๐ŸŽค Korean Speech Recognition API</h1>
53
- <p>์‹ค์‹œ๊ฐ„ ํ•œ๊ตญ์–ด ์Œ์„ฑ ์ธ์‹ ์„œ๋น„์Šค๊ฐ€ ์‹คํ–‰ ์ค‘์ž…๋‹ˆ๋‹ค.</p>
54
-
55
- <h3>API ์—”๋“œํฌ์ธํŠธ:</h3>
56
- <div class="endpoint">WebSocket /ws - ์‹ค์‹œ๊ฐ„ ์Œ์„ฑ ์ธ์‹</div>
57
- <div class="endpoint">POST /transcribe - ํŒŒ์ผ ์—…๋กœ๋“œ ์Œ์„ฑ ์ธ์‹</div>
58
- <div class="endpoint">GET /health - ํ—ฌ์Šค ์ฒดํฌ</div>
59
-
60
- <h3>WebSocket ํ…Œ์ŠคํŠธ:</h3>
61
- <button onclick="testWebSocket()">WebSocket ์—ฐ๊ฒฐ ํ…Œ์ŠคํŠธ</button>
62
- <div id="status"></div>
63
-
64
- <script>
65
- function testWebSocket() {
66
- const status = document.getElementById('status');
67
- status.innerHTML = '<div class="status warning">์—ฐ๊ฒฐ ์‹œ๋„ ์ค‘...</div>';
68
-
69
- try {
70
- const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
71
- const wsUrl = `${protocol}//${window.location.host}/ws`;
72
- console.log('WebSocket URL:', wsUrl);
73
-
74
- const ws = new WebSocket(wsUrl);
75
- let messageReceived = false;
76
-
77
- ws.onopen = () => {
78
- console.log('WebSocket ์—ฐ๊ฒฐ๋จ');
79
- status.innerHTML = '<div class="status success">โœ… WebSocket ์—ฐ๊ฒฐ ์„ฑ๊ณต! ๋ฉ”์‹œ์ง€ ์ „์†ก ์ค‘...</div>';
80
-
81
- // ํ…์ŠคํŠธ ๋ฉ”์‹œ์ง€ ์ „์†ก
82
- ws.send('ping');
83
- };
84
-
85
- ws.onmessage = (event) => {
86
- messageReceived = true;
87
- console.log('๋ฐ›์€ ๋ฉ”์‹œ์ง€:', event.data);
88
- try {
89
- const data = JSON.parse(event.data);
90
- status.innerHTML = '<div class="status success">โœ… ์„œ๋ฒ„ ์‘๋‹ต ์ˆ˜์‹  ์„ฑ๊ณต!</div>';
91
-
92
- // 3์ดˆ ํ›„ ์ •์ƒ์ ์œผ๋กœ ์—ฐ๊ฒฐ ์ข…๋ฃŒ
93
- setTimeout(() => {
94
- if (ws.readyState === WebSocket.OPEN) {
95
- ws.close(1000, 'Test completed successfully');
96
- }
97
- }, 3000);
98
- } catch (e) {
99
- status.innerHTML = '<div class="status success">โœ… ํ…์ŠคํŠธ ๋ฉ”์‹œ์ง€ ์ˆ˜์‹ : ' + event.data + '</div>';
100
- setTimeout(() => {
101
- if (ws.readyState === WebSocket.OPEN) {
102
- ws.close(1000, 'Test completed successfully');
103
- }
104
- }, 3000);
105
- }
106
- };
107
-
108
- ws.onerror = (error) => {
109
- console.error('WebSocket ์˜ค๋ฅ˜:', error);
110
- status.innerHTML = '<div class="status error">โŒ WebSocket ์—ฐ๊ฒฐ ์‹คํŒจ</div>';
111
- };
112
-
113
- ws.onclose = (event) => {
114
- console.log('WebSocket ์—ฐ๊ฒฐ ์ข…๋ฃŒ', event.code, event.reason);
115
-
116
- if (event.code === 1000) {
117
- status.innerHTML = '<div class="status success">โœ… WebSocket ํ…Œ์ŠคํŠธ ์™„๋ฃŒ (์ •์ƒ ์ข…๋ฃŒ)</div>';
118
- } else if (messageReceived) {
119
- status.innerHTML = '<div class="status success">โœ… ๋ฉ”์‹œ์ง€ ๊ตํ™˜ ์„ฑ๊ณต (์ฝ”๋“œ: ' + event.code + ')</div>';
120
- } else {
121
- status.innerHTML = '<div class="status error">โŒ WebSocket ์—ฐ๊ฒฐ์ด ์˜ˆ๊ธฐ์น˜ ์•Š๊ฒŒ ์ข…๋ฃŒ๋จ (์ฝ”๋“œ: ' + event.code + ')</div>';
122
- }
123
- };
124
-
125
- // 10์ดˆ ํ›„ ํƒ€์ž„์•„์›ƒ
126
- setTimeout(() => {
127
- if (ws.readyState === WebSocket.CONNECTING || ws.readyState === WebSocket.OPEN) {
128
- if (!messageReceived) {
129
- status.innerHTML = '<div class="status error">โŒ ์‘๋‹ต ํƒ€์ž„์•„์›ƒ</div>';
130
- ws.close();
131
- }
132
- }
133
- }, 10000);
134
-
135
- } catch (error) {
136
- console.error('WebSocket ํ…Œ์ŠคํŠธ ์˜ค๋ฅ˜:', error);
137
- status.innerHTML = '<div class="status error">โŒ WebSocket ํ…Œ์ŠคํŠธ ์‹คํŒจ: ' + error.message + '</div>';
138
- }
139
- }
140
- </script>
141
- </div>
142
- </body>
143
  </html>
144
  """
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  @app.get("/health")
147
  async def health_check():
148
- """ํ—ฌ์Šค ์ฒดํฌ"""
149
  return {
150
- "status": "healthy",
151
- "model": "?",
152
- "language": "korean",
153
- "version": "1.0.0",
154
  }
155
-
156
- @app.on_event("startup")
157
- async def startup_event():
158
- logger.info("๐Ÿš€ Korean Speech Recognition API ์„œ๋ฒ„ ์‹œ์ž‘๋จ")
159
-
160
- if __name__ == "__main__":
161
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ from fastapi import FastAPI, Request, Form
 
 
 
 
2
  from fastapi.responses import HTMLResponse
3
+ from pydantic import BaseModel
4
+ import yaml
5
+ from model.llama3 import LLaMA3
6
+
7
+ # ---------------- ์„ค์ • ๋กœ๋“œ ----------------
8
+ with open("data/config/llama3.yaml", "r") as f:
9
+ config = yaml.safe_load(f)
10
 
11
+ # ---------------- ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ----------------
12
+ llama3 = LLaMA3(config)
 
13
 
14
+ # ---------------- FastAPI ์•ฑ ----------------
15
  app = FastAPI(
16
+ title="Korean Pronunciation Correction API",
17
+ description="FastAPI + LLaMA3 ๊ธฐ๋ฐ˜ ๋ฐœ์Œ ๊ต์ • ์„œ๋ฒ„",
18
  version="1.0.0"
19
  )
20
 
21
+ # ---------------- ์ž…๋ ฅ ๋ชจ๋ธ ----------------
22
+ class InputData(BaseModel):
23
+ user_input: str
24
+ correct_input: str
 
 
 
 
25
 
26
+ # ---------------- API: JSON POST ----------------
27
+ @app.post("/generate")
28
+ async def generate_correction(data: InputData):
29
+ result = llama3.generate(data.user_input, data.correct_input)
30
+ return {"result": result}
31
+
32
+ # ---------------- HTML UI ----------------
33
  @app.get("/", response_class=HTMLResponse)
34
+ async def form_ui():
 
35
  return """
 
36
  <html>
37
+ <head>
38
+ <title>Korean Pronunciation Correction</title>
39
+ </head>
40
+ <body style="font-family: sans-serif; max-width: 600px; margin: auto; padding: 2rem;">
41
+ <h1>๐Ÿ—ฃ๏ธ ๋ฐœ์Œ ๊ต์ • ํ…Œ์ŠคํŠธ</h1>
42
+ <form action="/submit" method="post">
43
+ <label for="user_input">๐Ÿง ์ž˜๋ชป๋œ ๋ฐœ์Œ:</label><br>
44
+ <input type="text" id="user_input" name="user_input" value="๋ฐ•๋ผ" required><br><br>
45
+
46
+ <label for="correct_input">๐ŸŽฏ ์˜ฌ๋ฐ”๋ฅธ ๋ฐœ์Œ:</label><br>
47
+ <input type="text" id="correct_input" name="correct_input" value="๋ฐœ๋ผ" required><br><br>
48
+
49
+ <input type="submit" value="๊ต์ • ์‹คํ–‰">
50
+ </form>
51
+ </body>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  </html>
53
  """
54
 
55
+ # ---------------- ๊ฒฐ๊ณผ ๋ Œ๋”๋ง ----------------
56
+ @app.post("/submit", response_class=HTMLResponse)
57
+ async def handle_form(user_input: str = Form(...), correct_input: str = Form(...)):
58
+ result = llama3.generate(user_input, correct_input)
59
+ return f"""
60
+ <html>
61
+ <head><title>๊ฒฐ๊ณผ</title></head>
62
+ <body style="font-family: sans-serif; max-width: 600px; margin: auto; padding: 2rem;">
63
+ <h1>โœ… ๊ฒฐ๊ณผ</h1>
64
+ <p><strong>์ž…๋ ฅ๋œ ๋ฐœ์Œ:</strong> {user_input}</p>
65
+ <p><strong>์ •๋‹ต ๋ฐœ์Œ:</strong> {correct_input}</p>
66
+ <hr>
67
+ <h2>๐Ÿง  ๋ชจ๋ธ ์‘๋‹ต:</h2>
68
+ <pre>{result}</pre>
69
+ <br>
70
+ <a href="/">๋‹ค์‹œ ์‹œ๋„ํ•˜๊ธฐ</a>
71
+ </body>
72
+ </html>
73
+ """
74
+
75
+ # ---------------- ํ—ฌ์Šค ์ฒดํฌ ----------------
76
  @app.get("/health")
77
  async def health_check():
 
78
  return {
79
+ "status": "ok",
80
+ "model": config["model"]["id"],
81
+ "device": config["model"]["device"]
 
82
  }