# Course TA — Gradio chat app backed by the OpenAI Responses API with file_search.
import os
from typing import List, Tuple

import gradio as gr
from openai import OpenAI
import dotenv

dotenv.load_dotenv()

# Prefer the conventional OPENAI_API_KEY name (which the error message below
# already references); fall back to the legacy OPENAI_KEY so existing
# deployments keep working.
API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("OPENAI_KEY")
MODEL = os.getenv("OPENAI_MODEL", "gpt-5-mini")
VECTOR_STORE_ID = os.getenv("OPENAI_VECTOR_STORE_ID")  # required for file_search

if not API_KEY:
    raise RuntimeError("Please set OPENAI_API_KEY.")
if not VECTOR_STORE_ID:
    print("⚠️ OPENAI_VECTOR_STORE_ID not set. File search will be disabled.")

client = OpenAI(api_key=API_KEY)
# System instructions: answer strictly from the course corpus and cite the
# source files used. (The original text was garbled — "answer questions only
# found in ... When relevant, and cite files you used" — which weakens the
# instruction; rewritten as a coherent sentence.)
SYSTEM_PROMPT = (
    "You are a virtual teacher assistant. Answer questions using only the "
    "course documents provided, and cite the files you used. If something "
    "isn't in the corpus, say so."
)
def call_model(history: List[dict], user_msg: str) -> Tuple[str, List[str]]:
    """Call the Responses API, optionally with file_search over a vector store.

    Args:
        history: Prior turns as ``{"role": ..., "content": ...}`` dicts.
        user_msg: The new user message to send.

    Returns:
        Tuple of ``(response_text, citations)`` where ``citations`` is a list
        of unique filenames referenced by the response's annotations.
    """
    # Build the message list: system prompt, prior turns, then the new turn.
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    messages.extend(history)  # each should be {"role": "user"/"assistant", "content": "..."}
    messages.append({"role": "user", "content": user_msg})

    kwargs = {
        "model": MODEL,
        "input": messages,
        "temperature": 1.0,
        "include": ["file_search_call.results"],
    }
    if VECTOR_STORE_ID:
        kwargs["tools"] = [{"type": "file_search", "vector_store_ids": [VECTOR_STORE_ID]}]

    resp = client.responses.create(**kwargs)

    text = ""
    citations: List[str] = []
    try:
        text = resp.output_text
        # Walk every output item rather than hard-coding resp.output[3]
        # (the old code broke whenever the tool-call layout changed, e.g.
        # when file_search was disabled or produced a different item order).
        for item in resp.output:
            for part in getattr(item, "content", None) or []:
                for ann in getattr(part, "annotations", None) or []:
                    fname = getattr(ann, "filename", None)
                    if fname:
                        citations.append(fname)
    except Exception as e:
        # Best-effort extraction: log and return whatever was gathered.
        print(e)

    # Deduplicate while preserving first-seen order (list(set(...)) was
    # nondeterministic).
    citations = list(dict.fromkeys(citations))
    return text.strip(), citations
# --- Gradio callbacks (chat history is kept client-side)
def startup():
    """Seed the chat display with a greeting and return an empty raw history."""
    greeting = gr.ChatMessage(
        role="assistant",
        content="Hi! I’m your course TA. Ask me about the syllabus, lectures, or assignments.",
    )
    return [greeting], []
def chat(user_msg: str, history: List[gr.ChatMessage], raw_history: List[dict]):
    """Handle one user turn: call the model, then update both history tracks.

    Returns the updated visible chat messages and the updated raw
    ``{"role": ..., "content": ...}`` history kept in Gradio state.
    """
    # Copy the API-facing history so the caller's state object isn't mutated.
    updated_history = list(raw_history) if raw_history is not None else []

    text, citations = call_model(updated_history, user_msg)

    updated_history.append({"role": "user", "content": user_msg})

    assistant_content = text or "(No text content returned.)"
    if citations:
        sources = "\n".join(f"- {fname}" for fname in citations)
        assistant_content += "\n\n**Sources:**\n" + sources
    updated_history.append({"role": "assistant", "content": assistant_content})

    new_history = history + [
        gr.ChatMessage(role="user", content=user_msg),
        gr.ChatMessage(role="assistant", content=assistant_content),
    ]
    return new_history, updated_history
def reset_chat():
    """Return a fresh greeting message and a cleared raw history."""
    greeting = gr.ChatMessage(
        role="assistant",
        content="New chat started. How can I help?",
    )
    return [greeting], []
# --- Gradio UI wiring (chat history is kept client-side in gr.State) ---
with gr.Blocks(title="Course TA (Responses API + File Search)") as demo:
    gr.Markdown("## 🎒 Course TA\nChat with a TA connected to your course corpus (Responses API).")
    chatbox = gr.Chatbot(type="messages", show_copy_button=True, label="Assistant")
    # Minimal [{"role": ..., "content": ...}] history forwarded to the API.
    raw_history_state = gr.State([])
    txt = gr.Textbox(placeholder="Ask about the syllabus, due dates, lecture notes, etc.", autofocus=True)
    send_btn = gr.Button("Send", variant="primary")
    reset_btn = gr.Button("Start New Chat")

    demo.load(startup, outputs=[chatbox, raw_history_state])
    # Both the button and pressing Enter in the textbox submit a turn.
    send_btn.click(chat, inputs=[txt, chatbox, raw_history_state], outputs=[chatbox, raw_history_state])
    txt.submit(chat, inputs=[txt, chatbox, raw_history_state], outputs=[chatbox, raw_history_state])
    reset_btn.click(reset_chat, outputs=[chatbox, raw_history_state])

if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable inside a container;
    # PORT may be overridden by the hosting environment (defaults to 7860).
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))