File size: 4,218 Bytes
b26a345
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
from typing import List, Tuple

import gradio as gr
from openai import OpenAI
import dotenv

# Load environment variables from a local .env file, if present.
dotenv.load_dotenv()

# Read the documented OPENAI_API_KEY name first (the failure message below
# already told users to set it), falling back to the legacy OPENAI_KEY so
# existing deployments keep working.
API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("OPENAI_KEY")
MODEL = os.getenv("OPENAI_MODEL", "gpt-5-mini")
VECTOR_STORE_ID = os.getenv("OPENAI_VECTOR_STORE_ID")  # required for file_search

if not API_KEY:
    raise RuntimeError("Please set OPENAI_API_KEY.")
if not VECTOR_STORE_ID:
    # Degrade gracefully: the app still chats, just without document grounding.
    print("⚠️  OPENAI_VECTOR_STORE_ID not set. File search will be disabled.")

client = OpenAI(api_key=API_KEY)

# System prompt for every conversation. (Rewritten: the original sentence was
# grammatically broken — "answer questions only found in ... When relevant,
# and cite files you used.")
SYSTEM_PROMPT = (
    "You are a virtual teacher assistant. Answer questions using only the "
    "course documents provided, and cite the files you used when relevant. "
    "If something isn't in the corpus, say so."
)

def call_model(history: List[dict], user_msg: str) -> Tuple[str, List[str]]:
    """
    Call the Responses API, with optional file_search over a vector store.

    Args:
        history: Prior turns as ``{"role": "user"/"assistant", "content": str}``
            dicts (the system prompt is prepended here, not stored in history).
        user_msg: The new user message to answer.

    Returns:
        ``(text, citations)`` — the assistant's reply text and a sorted,
        de-duplicated list of cited filenames (empty when file search is
        disabled or produced no citations).
    """
    # Build the full message list for the Responses API.
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    messages.extend(history)
    messages.append({"role": "user", "content": user_msg})

    kwargs = {
        "model": MODEL,
        "input": messages,
        "temperature": 1.0,
        "include": ["file_search_call.results"],
    }
    if VECTOR_STORE_ID:
        kwargs["tools"] = [{"type": "file_search", "vector_store_ids": [VECTOR_STORE_ID]}]

    resp = client.responses.create(**kwargs)

    text = ""
    citations: List[str] = []
    try:
        # New SDKs expose a convenience `output_text` aggregate.
        text = resp.output_text
        # Walk ALL output items rather than assuming the assistant message
        # lives at a fixed position: the original `resp.output[3]` broke
        # whenever the number of tool-call items changed.
        for item in resp.output:
            for part in getattr(item, "content", None) or []:
                for ann in getattr(part, "annotations", None) or []:
                    fname = getattr(ann, "filename", None)
                    if fname:
                        citations.append(fname)
    except Exception as e:
        # Best-effort extraction: log and return whatever we have.
        print(e)
    # Sorted for a deterministic order (the original list(set(...)) was not).
    return text.strip(), sorted(set(citations))

# --- Gradio callbacks (chat history is kept client-side)
def startup():
    """Seed a new session: greeting message for the chat, empty raw history."""
    greeting = gr.ChatMessage(
        role="assistant",
        content="Hi! I’m your course TA. Ask me about the syllabus, lectures, or assignments.",
    )
    return [greeting], []

def chat(user_msg: str, history: List[gr.ChatMessage], raw_history: List[dict]):
    """Run one chat turn: call the model, then extend both history views."""
    # Copy the minimal dict-based history that the API consumes
    # (treat a missing/None state the same as an empty one).
    api_history = list(raw_history) if raw_history else []

    reply, cited_files = call_model(api_history, user_msg)

    # Compose the assistant bubble, appending a sources list when available.
    body = reply if reply else "(No text content returned.)"
    if cited_files:
        body += "\n\n**Sources:**\n" + "\n".join(f"- {fname}" for fname in cited_files)

    # Record the exchange in the raw (API-facing) history...
    api_history += [
        {"role": "user", "content": user_msg},
        {"role": "assistant", "content": body},
    ]

    # ...and in the visible chat transcript.
    visible = history + [
        gr.ChatMessage(role="user", content=user_msg),
        gr.ChatMessage(role="assistant", content=body),
    ]
    return visible, api_history

def reset_chat():
    """Drop both histories and greet the user in a fresh conversation."""
    fresh = gr.ChatMessage(role="assistant", content="New chat started. How can I help?")
    return [fresh], []


# --- UI wiring: build the Gradio Blocks app and hook callbacks to widgets.
with gr.Blocks(title="Course TA (Responses API + File Search)") as demo:
    gr.Markdown("## 🎒 Course TA\nChat with a TA connected to your course corpus (Responses API).")

    chatbox = gr.Chatbot(type="messages", show_copy_button=True, label="Assistant")
    raw_history_state = gr.State([])  # holds minimal [{"role":..., "content":...}] history

    txt = gr.Textbox(placeholder="Ask about the syllabus, due dates, lecture notes, etc.", autofocus=True)
    send_btn = gr.Button("Send", variant="primary")
    reset_btn = gr.Button("Start New Chat")

    # Greet on page load; both the Send button and Enter submit a turn.
    demo.load(startup, outputs=[chatbox, raw_history_state])
    send_btn.click(chat, inputs=[txt, chatbox, raw_history_state], outputs=[chatbox, raw_history_state])
    txt.submit(chat, inputs=[txt, chatbox, raw_history_state], outputs=[chatbox, raw_history_state])
    reset_btn.click(reset_chat, outputs=[chatbox, raw_history_state])

if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from inside containers;
    # the port is configurable via PORT (default 7860, Gradio's default).
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))