dprat0821 committed
Commit f7930f4 · verified · 1 Parent(s): 5ec4fbc

Update app.py

Files changed (1):
  app.py: +12 -6
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 import os
-import openai
+from openai import OpenAI
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 
@@ -8,8 +8,8 @@ import torch
 DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 
-# Initialize OpenAI client
-openai.api_key = OPENAI_API_KEY
+# Initialize OpenAI client (for openai>=1.0.0)
+client = OpenAI(api_key=OPENAI_API_KEY)
 
 # Load DeepSeek model
 deepseek_model_id = "deepseek-ai/deepseek-llm-7b-chat"
@@ -32,36 +32,42 @@ def generate_response(prompt, model_provider, temperature, top_p, max_tokens, repetition_penalty):
             repetition_penalty=repetition_penalty
         )
         return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
     elif model_provider == "OpenAI":
         try:
-            response = openai.ChatCompletion.create(
-                model="gpt-3.5-turbo",  # or another model of your choice
+            response = client.chat.completions.create(
+                model="gpt-3.5-turbo",  # or "gpt-4" if you have access
                 messages=[{"role": "user", "content": prompt}],
                 temperature=temperature,
                 top_p=top_p,
                 max_tokens=max_tokens,
                 presence_penalty=repetition_penalty
            )
-            return response.choices[0].message["content"].strip()
+            return response.choices[0].message.content.strip()
        except Exception as e:
            return f"OpenAI API Error: {str(e)}"
+
     else:
         return "Invalid model provider selected."
 
 with gr.Blocks() as demo:
     gr.Markdown("## 🔍 LLM Chat Interface")
+
     with gr.Row():
         model_provider = gr.Dropdown(
             choices=["DeepSeek", "OpenAI"],
             value="DeepSeek",
             label="Select Model Provider"
         )
+
     prompt = gr.Textbox(label="Enter your prompt", lines=4, placeholder="Type your message here...")
+
     with gr.Accordion("Advanced Settings", open=False):
         temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
         top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
         max_tokens = gr.Slider(32, 2048, value=512, step=32, label="Max New Tokens")
         repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")
+
     output = gr.Textbox(label="Response")
     submit = gr.Button("Generate")
73