aiqtech committed on
Commit
6c7c2aa
·
verified ·
1 Parent(s): 4b68f33

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +332 -0
app.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import asyncio
4
+ from typing import Optional, List, Dict
5
+ from contextlib import asynccontextmanager
6
+
7
+ import requests
8
+ import uvicorn
9
+ from fastapi import FastAPI, HTTPException
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from pydantic import BaseModel, Field
12
+ import gradio as gr
13
+
14
+
15
# Pydantic model definitions
class Message(BaseModel):
    """One turn of a chat conversation."""

    # Who produced the turn, e.g. "user" or "assistant".
    role: str = Field(...)
    # Text of the turn.
    content: str = Field(...)
19
+
20
+
21
class ChatRequest(BaseModel):
    """Chat-completion request: the conversation plus sampling options.

    Every numeric option is range-checked by its ``Field`` constraints, so
    out-of-range values are rejected when the model is constructed.
    """

    # Conversation turns forwarded to the model.
    messages: List[Message] = Field(...)
    # Fireworks model identifier used when the caller does not pick one.
    model: str = Field("accounts/fireworks/models/qwen3-235b-a22b-instruct-2507")
    max_tokens: int = Field(4096, ge=1, le=8192)
    temperature: float = Field(0.6, ge=0, le=2)
    top_p: float = Field(1.0, ge=0, le=1)
    top_k: int = Field(40, ge=1, le=100)
    presence_penalty: float = Field(0, ge=-2, le=2)
    frequency_penalty: float = Field(0, ge=-2, le=2)
30
+
31
+
32
class ChatResponse(BaseModel):
    """Body returned by the ``/chat`` endpoint."""

    # Generated assistant text.
    response: str = Field(...)
    # Model identifier the request was made with.
    model: str = Field(...)
    # Total token count reported by the upstream API, when available.
    tokens_used: Optional[int] = Field(None)
36
+
37
+
38
# Fireworks API client
class FireworksClient:
    """Small synchronous client for the Fireworks chat-completions endpoint."""

    def __init__(self, api_key: Optional[str] = None, timeout: float = 30):
        """Store the credentials and prepare the request headers.

        Args:
            api_key: Fireworks API key. When omitted or empty, the
                ``FIREWORKS_API_KEY`` environment variable is used instead.
            timeout: Timeout in seconds passed to ``requests.post``.

        Raises:
            ValueError: If no key is given and the environment variable is
                unset or empty.
        """
        self.api_key = api_key or os.getenv("FIREWORKS_API_KEY")
        if not self.api_key:
            raise ValueError("API key is required. Set FIREWORKS_API_KEY environment variable.")

        self.timeout = timeout
        self.base_url = "https://api.fireworks.ai/inference/v1/chat/completions"
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    @staticmethod
    def _dump_message(msg) -> Dict:
        """Serialize a message model to a dict on both Pydantic v1 and v2."""
        # Pydantic v2 deprecates ``.dict()`` in favour of ``.model_dump()``.
        dump = getattr(msg, "model_dump", None)
        return dump() if callable(dump) else msg.dict()

    def chat(self, request: "ChatRequest") -> Dict:
        """Send a chat request to the Fireworks API and return the decoded JSON.

        Args:
            request: Validated chat request (messages and sampling options).

        Returns:
            The JSON body of the API response as a dict.

        Raises:
            HTTPException: With status 500 when the HTTP request fails or the
                API answers with an error status.
        """
        payload = {
            "model": request.model,
            "max_tokens": request.max_tokens,
            "top_p": request.top_p,
            "top_k": request.top_k,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
            "temperature": request.temperature,
            "messages": [self._dump_message(msg) for msg in request.messages],
        }

        try:
            response = requests.post(
                self.base_url,
                headers=self.headers,
                json=payload,  # requests serializes the body itself
                timeout=self.timeout,
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            # Chain the cause so the original network error stays in tracebacks.
            raise HTTPException(status_code=500, detail=f"API request failed: {str(e)}") from e
76
+
77
+
78
# Gradio app construction
def create_gradio_app(client: FireworksClient):
    """Build the Gradio Blocks chat UI that talks to the LLM through ``client``.

    Args:
        client: Fireworks client used to answer every chat submission.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """

    def chat_with_llm(
        message: str,
        history: Optional[List[List[str]]],
        model: str,
        temperature: float,
        max_tokens: int,
        top_p: float,
        top_k: int,
    ):
        """Send the conversation plus ``message`` to the LLM.

        Returns:
            A ``(textbox_value, history)`` pair. The textbox value is always
            the empty string; the history gains one ``[user, assistant]``
            pair, whose assistant text is an error notice if the call failed.
        """
        # The clear button sets the chatbot value to None, so the history may
        # arrive as None; treat that as an empty conversation.
        if history is None:
            history = []

        if not message:
            return "", history

        # Convert the [user, assistant] pairs into API messages.
        messages = []
        for user_msg, assistant_msg in history:
            if user_msg:
                messages.append(Message(role="user", content=user_msg))
            if assistant_msg:
                messages.append(Message(role="assistant", content=assistant_msg))

        # Append the message being sent now.
        messages.append(Message(role="user", content=message))

        try:
            request = ChatRequest(
                messages=messages,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                top_k=top_k,
            )
            response = client.chat(request)

            # Pull the generated text out of the first choice, if any.
            choices = response.get("choices") or []
            if choices:
                assistant_response = choices[0]["message"]["content"]
            else:
                assistant_response = "응답을 받을 수 없습니다."
        except Exception as e:
            # UI boundary: show any failure in the chat instead of raising.
            assistant_response = f"오류 발생: {str(e)}"

        history.append([message, assistant_response])
        return "", history

    # Gradio interface layout
    with gr.Blocks(title="LLM Chat Interface") as demo:
        gr.Markdown("# 🚀 Fireworks LLM Chat Interface")
        gr.Markdown("Qwen3-235B 모델을 사용한 채팅 인터페이스입니다.")

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    height=500,
                    label="채팅 창",
                )
                msg = gr.Textbox(
                    label="메시지 입력",
                    placeholder="메시지를 입력하세요...",
                    lines=2,
                )
                with gr.Row():
                    submit = gr.Button("전송", variant="primary")
                    clear = gr.Button("대화 초기화")

            with gr.Column(scale=1):
                gr.Markdown("### ⚙️ 설정")
                model = gr.Textbox(
                    label="모델",
                    value="accounts/fireworks/models/qwen3-235b-a22b-instruct-2507",
                    interactive=True,
                )
                temperature = gr.Slider(
                    minimum=0,
                    maximum=2,
                    value=0.6,
                    step=0.1,
                    label="Temperature",
                )
                max_tokens = gr.Slider(
                    minimum=100,
                    maximum=8192,
                    value=4096,
                    step=100,
                    label="Max Tokens",
                )
                top_p = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=1.0,
                    step=0.1,
                    label="Top P",
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=40,
                    step=1,
                    label="Top K",
                )

        # Event handlers: the send button and Enter in the textbox share the
        # same callback, inputs and outputs.
        chat_inputs = [msg, chatbot, model, temperature, max_tokens, top_p, top_k]
        chat_outputs = [msg, chatbot]
        submit.click(chat_with_llm, inputs=chat_inputs, outputs=chat_outputs)
        msg.submit(chat_with_llm, inputs=chat_inputs, outputs=chat_outputs)

        # Reset the conversation.
        clear.click(lambda: None, None, chatbot, queue=False)

    return demo
206
+
207
+
208
# FastAPI app setup
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Print a log line when the application starts and when it shuts down.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    # On startup
    print("🚀 Starting FastAPI + Gradio server...")
    yield
    # On shutdown
    print("👋 Shutting down server...")
217
+
218
+
219
app = FastAPI(
    title="LLM API with Gradio Interface",
    description="Fireworks LLM API with Gradio testing interface",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS configuration
# Wildcard origins must not be combined with credentialed requests, and this
# API uses no cookies or auth headers, so credentials stay disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
234
+
235
# Initialise the Fireworks client; stays None when no API key is available so
# the module can still be imported and the server can still start.
fireworks_client = None
try:
    fireworks_client = FireworksClient()
except ValueError as missing_key_error:
    print(f"⚠️ Warning: {missing_key_error}")
    print("API endpoints will not work without a valid API key.")
242
+
243
+
244
# API endpoints
@app.get("/")
async def root():
    """Report that the server is running and list its main routes."""
    endpoints = {
        "api": "/chat",
        "gradio": "/gradio",
        "docs": "/docs",
    }
    return {"message": "LLM API Server is running", "endpoints": endpoints}
256
+
257
+
258
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Chat API endpoint: forward the request to Fireworks and return the reply.

    Args:
        request: Validated chat request body.

    Returns:
        A ``ChatResponse`` with the first choice's text, the requested model
        name and the total token count when the upstream API reports one.

    Raises:
        HTTPException: With status 500 when no API key is configured, the
            upstream call fails, or the upstream response has no choices.
    """
    if not fireworks_client:
        raise HTTPException(status_code=500, detail="API key not configured")

    try:
        # FireworksClient.chat performs blocking HTTP I/O; run it in the
        # default executor so the event loop can keep serving other requests.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, fireworks_client.chat, request)

        # Parse the response
        choices = response.get("choices")
        if not choices:
            raise HTTPException(status_code=500, detail="Invalid response from API")

        usage = response.get("usage") or {}
        return ChatResponse(
            response=choices[0]["message"]["content"],
            model=request.model,
            tokens_used=usage.get("total_tokens"),
        )
    except HTTPException:
        # Already carries the intended status and detail.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
284
+
285
+
286
@app.get("/health")
async def health_check():
    """Health-check endpoint: report status and whether a client is configured."""
    api_configured = fireworks_client is not None
    return {"status": "healthy", "api_configured": api_configured}
293
+
294
+
295
# Mount the Gradio UI under /gradio, only when a client could be created.
if fireworks_client is not None:
    gradio_app = create_gradio_app(fireworks_client)
    app = gr.mount_gradio_app(app, gradio_app, path="/gradio")
299
+
300
+
301
# Main entry point
if __name__ == "__main__":
    # Check for the API key
    if not os.getenv("FIREWORKS_API_KEY"):
        print("⚠️ 경고: FIREWORKS_API_KEY 환경변수가 설정되지 않았습니다.")
        print("설정 방법:")
        print(" Linux/Mac: export FIREWORKS_API_KEY='your-api-key'")
        print(" Windows: set FIREWORKS_API_KEY=your-api-key")
        print("")

        # Optionally read the key interactively. With stdin closed (for
        # example in a container or service) input() raises EOFError; treat
        # that the same as pressing Enter so the server still starts.
        try:
            api_key = input("API 키를 입력하세요 (Enter를 누르면 건너뜁니다): ").strip()
        except EOFError:
            api_key = ""

        if api_key:
            os.environ["FIREWORKS_API_KEY"] = api_key
            fireworks_client = FireworksClient(api_key)
            gradio_app = create_gradio_app(fireworks_client)
            app = gr.mount_gradio_app(app, gradio_app, path="/gradio")

    # Start the server
    print("\n🚀 서버를 시작합니다...")
    # NOTE(review): the leading emoji on the next line was restored from a
    # mis-encoded copy with a dropped byte; confirm it is the intended glyph.
    print("📝 API 문서: http://localhost:8000/docs")
    print("💬 Gradio UI: http://localhost:8000/gradio")
    print("🔧 API 엔드포인트: http://localhost:8000/chat")

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        reload=False,
    )