start3406 committed on
Commit
b04f0ed
·
verified ·
1 Parent(s): 0d1a40e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -29
app.py CHANGED
@@ -2,47 +2,77 @@ import gradio as gr
2
  import os
3
  from openai import OpenAI
4
 
5
- OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
6
- openai_client = OpenAI(api_key=OPENAI_API_KEY)
7
-
8
  DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
9
- deepseek_client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
10
 
11
- def generate_response(prompt, model_provider, temperature, top_p, max_tokens, repetition_penalty):
12
- models = {
13
- "DeepSeek": ("deepseek-chat", deepseek_client),
14
- "OpenAI": ("gpt-3.5-turbo", openai_client)
 
 
 
 
 
 
 
 
 
15
  }
16
- model, client = models[model_provider]
17
 
18
  try:
19
- response = client.chat.completions.create(
20
  model=model,
21
  messages=[{"role": "user", "content": prompt}],
22
  temperature=temperature,
23
  top_p=top_p,
24
  max_tokens=max_tokens,
25
- presence_penalty=repetition_penalty,
26
- stream=False
 
27
  )
28
- return response.choices[0].message.content.strip()
29
  except Exception as e:
30
- return f"{model_provider} API Error: {str(e)}"
31
 
32
- iface = gr.Interface(
33
- fn=generate_response,
34
- inputs=[
35
- gr.Dropdown(choices=["DeepSeek", "OpenAI"], value="DeepSeek", label="Model Provider",interactive=False),
36
- gr.Textbox(label="Prompt", lines=6, placeholder="Ask something..."),
37
- gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature"),
38
- gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p"),
39
- gr.Slider(minimum=32, maximum=2048, value=512, step=32, label="Max New Tokens"),
40
- gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.1, label="Repetition Penalty")
41
- ],
42
- outputs="text",
43
- title="🧠 DeepSeek LLM Chat with Parameter Tuning",
44
- theme=gr.themes.Soft()
45
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  iface.launch()
48
-
 
import os

from openai import OpenAI

# Both API keys are read from the environment.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")

# One client per provider; DeepSeek exposes an OpenAI-compatible endpoint.
# NOTE(review): whether the base URL needs the /v1 suffix depends on the
# current DeepSeek API docs — confirm before changing it.
openai_client = OpenAI(api_key=OPENAI_API_KEY)
deepseek_client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com/v1")
16
+
17
def generate_response(model_provider, prompt, temperature, top_p, max_tokens, repetition_penalty):
    """Send *prompt* to the selected provider and return the reply text.

    Parameters
    ----------
    model_provider : str
        "DeepSeek" or "OpenAI" — selects the client/model pair.
    prompt : str
        User message, sent as a single-turn chat.
    temperature, top_p : float
        Sampling parameters forwarded verbatim to the API.
    max_tokens : int
        Upper bound on generated tokens.
    repetition_penalty : float
        Mapped onto the OpenAI-style ``frequency_penalty`` field
        (``presence_penalty`` is left neutral).

    Returns
    -------
    str
        The model's reply, or a human-readable error message on failure.
    """
    # Dispatch table: provider name -> (client, model id).
    clients = {
        "DeepSeek": (deepseek_client, "deepseek-chat"),
        "OpenAI": (openai_client, "gpt-3.5-turbo"),
    }
    entry = clients.get(model_provider)
    if entry is None:
        # Guard: an unknown dropdown value previously raised KeyError
        # outside the try block and crashed the Gradio handler.
        return f"Unknown model provider: {model_provider}"
    client, model = entry

    try:
        resp = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            # The UI's "repetition penalty" slider maps to frequency_penalty.
            frequency_penalty=repetition_penalty,
            presence_penalty=0.0
        )
        return resp.choices[0].message.content.strip()
    except Exception as e:
        # Surface API failures in the UI instead of crashing the app.
        return f"{model_provider} API Error: {e}"
39
 
40
# UI layout: provider/sampling controls on top, prompt box, then a
# length/penalty row, output box, and the submit button.
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("## 🧠 DeepSeek / OpenAI 聊天演示(可调参)")
    with gr.Row():
        provider = gr.Dropdown(
            choices=["DeepSeek", "OpenAI"],
            value="DeepSeek",
            label="模型供应商"
        )
        temperature = gr.Slider(
            minimum=0.1, maximum=1.5, step=0.1, value=0.7,
            label="Temperature"
        )
        top_p = gr.Slider(
            minimum=0.1, maximum=1.0, step=0.05, value=0.9,
            label="Top-p"
        )
    prompt = gr.Textbox(
        label="Prompt",
        lines=6,
        placeholder="在这里输入你的问题……"
    )
    with gr.Row():
        max_tokens = gr.Slider(
            minimum=32, maximum=2048, step=32, value=512,
            label="Max Tokens"
        )
        rep_penalty = gr.Slider(
            # OpenAI-style frequency_penalty accepts -2.0..2.0; the old
            # floor of 0.0 hid the negative half of the valid range.
            # Default (1.1) is unchanged.
            minimum=-2.0, maximum=2.0, step=0.1, value=1.1,
            label="Frequency Penalty"
        )
    output = gr.Textbox(label="Response")
    btn = gr.Button("生成回答")
    btn.click(
        fn=generate_response,
        inputs=[provider, prompt, temperature, top_p, max_tokens, rep_penalty],
        outputs=output
    )
77
 
78
# Start the Gradio server (blocks until the process is stopped).
iface.launch()