File size: 2,913 Bytes
9c52a51
a156bf2
762ce3f
6e9af51
487cc1b
6e9af51
762ce3f
 
 
632c528
762ce3f
6e9af51
762ce3f
632c528
 
6e9af51
632c528
6e9af51
632c528
 
 
cd848e7
632c528
cd848e7
632c528
6e9af51
632c528
 
 
6e9af51
0a4832a
632c528
 
6e9af51
 
0a4832a
 
 
 
 
 
 
 
632c528
 
6e9af51
632c528
 
 
 
762ce3f
632c528
762ce3f
632c528
762ce3f
632c528
 
6e9af51
632c528
 
 
 
 
 
 
0a4832a
 
 
632c528
 
 
 
 
0a4832a
 
632c528
 
0a4832a
 
 
632c528
0a4832a
632c528
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import functools
from openai import OpenAI
import tiktoken
import gradio as gr

# --- OpenAI client configuration ---
# Fail fast at import time when no credentials are present.
api_key = os.environ.get("OPENAI_API_KEY")
if api_key is None or api_key == "":
    raise ValueError("Please set the OPENAI_API_KEY environment variable.")
client = OpenAI(api_key=api_key)

# Model choices come from OPENAI_MODEL_LIST (comma-separated), with a
# built-in default; blank entries are dropped and a final fallback applied.
_raw_models = os.environ.get("OPENAI_MODEL_LIST", "gpt-3.5-turbo,gpt-4")
ALL_MODELS = []
for _piece in _raw_models.split(","):
    _piece = _piece.strip()
    if _piece:
        ALL_MODELS.append(_piece)
ALL_MODELS = ALL_MODELS or ["gpt-3.5-turbo"]

# Token counter using tiktoken
@functools.lru_cache(maxsize=64)
def _get_encoding(model: str):
    """Return the tiktoken encoding for *model*, cached per model name.

    Unknown model names fall back to the general-purpose cl100k_base
    encoding rather than raising.
    """
    try:
        enc = tiktoken.encoding_for_model(model)
    except KeyError:
        enc = tiktoken.get_encoding("cl100k_base")
    return enc

def count_tokens(text: str, model: str) -> int:
    """Return the number of tokens *text* occupies under *model*'s encoding."""
    return len(_get_encoding(model).encode(text))

# Read uploaded file content (filepath or file-like)
def read_file_content(file_obj):
    """Return an uploaded file's text wrapped in start/end marker lines.

    Accepts either a filesystem path (str) or a file-like object exposing
    ``read()`` — Gradio supplies one or the other depending on the File
    component's ``type``.  Returns "" for a falsy argument or any
    unreadable file so the chat flow degrades gracefully instead of
    crashing.

    Parameters:
        file_obj: str path, file-like object, or None/"".

    Returns:
        str: "--- Start of file ---"-delimited content block, or "".
    """
    if not file_obj:
        return ""
    try:
        if isinstance(file_obj, str):
            with open(file_obj, 'r', encoding='utf-8') as f:
                content = f.read()
            name = os.path.basename(file_obj)
        else:
            name = getattr(file_obj, 'name', 'uploaded_file')
            data = file_obj.read()
            # Some file-like objects return bytes, others already-decoded
            # text; the original unconditionally called .decode() and broke
            # on the latter.
            content = data.decode('utf-8') if isinstance(data, bytes) else data
        return f"\n\n--- Start of file: {name} ---\n{content}\n--- End of file ---\n"
    except (OSError, UnicodeDecodeError, AttributeError):
        # Best-effort by design: an unreadable file simply contributes
        # nothing to the prompt.  Narrowed from a bare `except Exception`
        # so genuine programming errors are no longer swallowed.
        return ""

# Chat response function
def respond(message, history, model_name, file_obj):
    """Handle one chat turn: query the model and append to the history.

    Parameters:
        message: the user's new message text (may be None/empty).
        history: list of OpenAI-style {"role", "content"} dicts, or None.
        model_name: model identifier to send to the API.
        file_obj: optional uploaded file (path or file-like) whose content
            is appended to the prompt.

    Returns:
        The updated history list (user turn + assistant turn with a token
        usage footer).
    """
    history = history or []
    file_text = read_file_content(file_obj)
    full_input = (message or "") + file_text

    # Send prior turns along with the new message so the model keeps
    # conversational context.  (Previously only the latest message was
    # sent, so the model forgot the conversation every turn.)
    messages = [
        {"role": m["role"], "content": m["content"]}
        for m in history
        if isinstance(m, dict) and m.get("role") in ("user", "assistant")
    ]
    messages.append({"role": "user", "content": full_input})
    resp = client.chat.completions.create(
        model=model_name,
        messages=messages
    )
    reply = resp.choices[0].message.content

    # Prefer the server-reported usage (exact) over the local tiktoken
    # estimate; fall back to the estimate if usage is absent.
    usage = getattr(resp, "usage", None)
    if usage is not None:
        prompt_tokens = usage.prompt_tokens
        completion_tokens = usage.completion_tokens
    else:
        prompt_tokens = count_tokens(full_input, model_name)
        completion_tokens = count_tokens(reply, model_name)
    usage_info = f"(Tokens used: prompt={prompt_tokens}, completion={completion_tokens})"

    # Update history (OpenAI-style messages)
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": f"{reply}\n\n{usage_info}"})
    return history

# Build Gradio interface.
# NOTE: component creation order inside the Blocks context determines the
# on-screen layout, and everything here (including launch) runs at import.
with gr.Blocks() as demo:
    # Model picker populated from ALL_MODELS (env-configurable above).
    model_dropdown = gr.Dropdown(ALL_MODELS, value=ALL_MODELS[0], label="Select Model")
    # type="messages" keeps the history as {"role", "content"} dicts,
    # matching what respond() appends.
    chatbot = gr.Chatbot(label="Chat with AI", type="messages")
    # type="filepath" means respond() receives a path string, which
    # read_file_content handles via its str branch.
    file_upload = gr.File(label="Upload File", file_types=[".txt", ".md", ".py"], type="filepath")
    user_input = gr.Textbox(placeholder="Type your message...", show_label=False)

    # First submit handler runs the chat turn and updates the chatbot.
    user_input.submit(respond,
                      inputs=[user_input, chatbot, model_dropdown, file_upload],
                      outputs=chatbot)
    # Second submit handler clears the textbox after sending.
    user_input.submit(lambda: "", None, user_input)

    # NOTE(review): launch() is called inside the Blocks context and at
    # module level, so importing this file starts the server — confirm
    # that is intended (an `if __name__ == "__main__":` guard is typical).
    demo.launch()