Spaces:
Sleeping
Sleeping
import html

import yaml
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

from model.llama3 import LLaMA3
# ---------------- Load configuration ----------------
# YAML streams are Unicode (UTF-8 by default); pass the encoding explicitly so
# parsing does not depend on the platform's locale encoding.
# The relative path is resolved against the current working directory.
with open("config/llama3.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# ---------------- Model initialization ----------------
# Built once, at import time; the request handlers below call llama3.generate.
llama3 = LLaMA3(config)

# ---------------- FastAPI app ----------------
# NOTE(review): the description string appears mis-encoded (mojibake) in this
# copy; presumably it is Korean text — confirm against the original source.
app = FastAPI(
    title="Korean Pronunciation Correction API",
    description="FastAPI + LLaMA3 ๊ธฐ๋ฐ ๋ฐ์ ๊ต์ ์๋ฒ",
    version="1.0.0",
)
# ---------------- Input model ----------------
class InputData(BaseModel):
    """JSON request body for ``generate_correction``.

    Both fields are required strings and are passed, in this order, to
    ``llama3.generate``.
    """

    # Presumably the mispronounced text (the HTML form labels its field of the
    # same name as the wrong pronunciation) — confirm.
    user_input: str
    # Presumably the target/correct pronunciation — confirm.
    correct_input: str
# ---------------- API: JSON POST ----------------
# NOTE(review): no route decorator is visible for this handler; unless it is
# registered elsewhere (e.g. app.add_api_route), it is unreachable. The intended
# path cannot be recovered from here — confirm.
async def generate_correction(data: InputData):
    """Return the model's correction for a JSON ``InputData`` payload.

    Args:
        data: Request body carrying ``user_input`` and ``correct_input``.

    Returns:
        ``{"result": <value returned by llama3.generate>}``.
    """
    # NOTE(review): llama3.generate is called without await inside an async
    # handler, so it runs on the event-loop thread for the whole call.
    generated = llama3.generate(data.user_input, data.correct_input)
    response_body = dict(result=generated)
    return response_body
# ---------------- HTML UI ----------------
@app.get("/", response_class=HTMLResponse)
async def form_ui():
    """Serve the pronunciation-correction test form.

    Registered for ``GET /``, the target of the "back"/"try again" links on the
    result and error pages. The form posts ``user_input`` and
    ``correct_input`` to ``/submit``. ``response_class=HTMLResponse`` makes the
    returned string render as HTML instead of a JSON-encoded string.

    Returns:
        A static HTML document as ``str``.
    """
    # NOTE(review): the non-ASCII page text appears mis-encoded (mojibake) in
    # this copy; confirm against the original source.
    return """
<html>
<head>
<title>Korean Pronunciation Correction</title>
</head>
<body style="font-family: sans-serif; max-width: 600px; margin: auto; padding: 2rem;">
<h1>๐ฃ๏ธ ๋ฐ์ ๊ต์ ํ ์คํธ</h1>
<form action="/submit" method="post">
<label for="user_input">๐ง ์๋ชป๋ ๋ฐ์:</label><br>
<input type="text" id="user_input" name="user_input" value="๋ฐ๋ผ" required><br><br>
<label for="correct_input">๐ฏ ์ฌ๋ฐ๋ฅธ ๋ฐ์:</label><br>
<input type="text" id="correct_input" name="correct_input" value="๋ฐ๋ผ" required><br><br>
<input type="submit" value="๊ต์ ์คํ">
</form>
</body>
</html>
"""
# ---------------- Result rendering ----------------
| from fastapi.responses import HTMLResponse | |
| from fastapi import Form, Request | |
| import traceback | |
@app.post("/submit", response_class=HTMLResponse)
async def handle_form(request: Request, user_input: str = Form(...), correct_input: str = Form(...)):
    """Run the model on the submitted form fields and render an HTML page.

    Registered for ``POST /submit``, which is the ``action``/``method`` of the
    HTML form served at ``/``.

    Args:
        request: The incoming request (not used by this handler).
        user_input: Value of the ``user_input`` form field.
        correct_input: Value of the ``correct_input`` form field.

    Returns:
        An HTML document (``str``) showing the inputs and the model response,
        or, if ``llama3.generate`` raised, the exception message and traceback.
        All interpolated values are HTML-escaped so user- or model-controlled
        text is rendered as text, not markup (prevents reflected XSS).
    """
    # NOTE(review): the non-ASCII page text appears mis-encoded (mojibake) in
    # this copy; confirm against the original source.
    # NOTE(review): llama3.generate is called without await, so it runs on the
    # event-loop thread for the whole call.
    try:
        result = llama3.generate(user_input, correct_input)
    except Exception as e:
        # Request boundary: render the failure as a page rather than crash.
        # NOTE(review): showing the full traceback to clients exposes server
        # internals; consider logging it server-side instead.
        error_message = html.escape(str(e))
        error_details = html.escape(traceback.format_exc())
        return f"""
<html>
<head><title>์๋ฌ</title></head>
<body style="font-family: sans-serif; max-width: 600px; margin: auto; padding: 2rem;">
<h1>โ ์๋ฒ ์ค๋ฅ ๋ฐ์</h1>
<p><strong>์ค๋ฅ ๋ฉ์์ง:</strong></p>
<pre>{error_message}</pre>
<hr>
<p><strong>์๋ฌ ์์ธ:</strong></p>
<pre>{error_details}</pre>
<br>
<a href="/">๋์๊ฐ๊ธฐ</a>
</body>
</html>
"""
    # Form fields and model output are untrusted; escape before embedding.
    safe_user_input = html.escape(user_input)
    safe_correct_input = html.escape(correct_input)
    safe_result = html.escape(str(result))
    return f"""
<html>
<head><title>๊ฒฐ๊ณผ</title></head>
<body style="font-family: sans-serif; max-width: 600px; margin: auto; padding: 2rem;">
<h1>โ ๊ฒฐ๊ณผ</h1>
<p><strong>์ ๋ ฅ๋ ๋ฐ์:</strong> {safe_user_input}</p>
<p><strong>์ ๋ต ๋ฐ์:</strong> {safe_correct_input}</p>
<hr>
<h2>๐ง ๋ชจ๋ธ ์๋ต:</h2>
<pre>{safe_result}</pre>
<br>
<a href="/">๋ค์ ์๋ํ๊ธฐ</a>
</body>
</html>
"""
# ---------------- Health check ----------------
# NOTE(review): no route decorator is visible for this handler; confirm how
# (and under which path) it is registered on `app`.
async def health_check():
    """Report a fixed "ok" status plus the configured model id and device.

    Returns:
        A dict with ``status`` (always ``"ok"``), ``model`` and ``device``,
        the latter two read from the ``model`` section of the loaded config.

    Raises:
        KeyError: if the config lacks ``model``, ``model.id`` or ``model.device``.
    """
    model_section = config["model"]
    return dict(
        status="ok",
        model=model_section["id"],
        device=model_section["device"],
    )