dprat0821 commited on
Commit
29745d3
·
verified ·
1 Parent(s): df51c66

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import openai
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ import torch
6
+
7
+ # Set your API keys as environment variables or replace os.getenv with your actual keys
8
+ DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
9
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
10
+
11
+ # Initialize OpenAI client
12
+ openai.api_key = OPENAI_API_KEY
13
+
14
+ # Load DeepSeek model
15
+ deepseek_model_id = "deepseek-ai/deepseek-llm-7b-chat"
16
+ tokenizer = AutoTokenizer.from_pretrained(deepseek_model_id)
17
+ deepseek_model = AutoModelForCausalLM.from_pretrained(
18
+ deepseek_model_id,
19
+ torch_dtype=torch.float16,
20
+ device_map="auto"
21
+ )
22
+
23
+ def generate_response(prompt, model_provider, temperature, top_p, max_tokens, repetition_penalty):
24
+ if model_provider == "DeepSeek":
25
+ inputs = tokenizer(prompt, return_tensors="pt").to(deepseek_model.device)
26
+ outputs = deepseek_model.generate(
27
+ **inputs,
28
+ do_sample=True,
29
+ temperature=temperature,
30
+ top_p=top_p,
31
+ max_new_tokens=max_tokens,
32
+ repetition_penalty=repetition_penalty
33
+ )
34
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
35
+ elif model_provider == "OpenAI":
36
+ try:
37
+ response = openai.ChatCompletion.create(
38
+ model="gpt-3.5-turbo", # or another model of your choice
39
+ messages=[{"role": "user", "content": prompt}],
40
+ temperature=temperature,
41
+ top_p=top_p,
42
+ max_tokens=max_tokens,
43
+ presence_penalty=repetition_penalty
44
+ )
45
+ return response.choices[0].message["content"].strip()
46
+ except Exception as e:
47
+ return f"OpenAI API Error: {str(e)}"
48
+ else:
49
+ return "Invalid model provider selected."
50
+
51
+ with gr.Blocks() as demo:
52
+ gr.Markdown("## 🔍 LLM Chat Interface")
53
+ with gr.Row():
54
+ model_provider = gr.Dropdown(
55
+ choices=["DeepSeek", "OpenAI"],
56
+ value="DeepSeek",
57
+ label="Select Model Provider"
58
+ )
59
+ prompt = gr.Textbox(label="Enter your prompt", lines=4, placeholder="Type your message here...")
60
+ with gr.Accordion("Advanced Settings", open=False):
61
+ temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
62
+ top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
63
+ max_tokens = gr.Slider(32, 2048, value=512, step=32, label="Max New Tokens")
64
+ repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")
65
+ output = gr.Textbox(label="Response")
66
+ submit = gr.Button("Generate")
67
+
68
+ submit.click(
69
+ fn=generate_response,
70
+ inputs=[prompt, model_provider, temperature, top_p, max_tokens, repetition_penalty],
71
+ outputs=output
72
+ )
73
+
74
+ demo.launch()