# AgileTaskGen-App / streamlit_app.py
# Razieh87's picture
# Update streamlit_app.py
# 17b71fc verified
# -*- coding: utf-8 -*-
# app.py
import os
import json
import pickle
import streamlit as st
from typing import List, Any, Literal
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from langchain_core.messages import AIMessage
from pydantic import BaseModel, Field
from typing import Literal as TypingLiteral
from langchain_core.documents import Document
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain.chat_models import init_chat_model
from langchain.tools import tool
from langchain_core.vectorstores import InMemoryVectorStore
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import MemorySaver
# -----------------------------
# Streamlit page config
# -----------------------------
# Remote image reused as both the browser-tab icon and the in-page header logo.
AGILE_ICON_URL = "https://raw.githubusercontent.com/RaziehAkbari2020/Agentic_RAG_Agile_Task_Management.Streamlit/main/Agile.png"

# Page-level settings: tab title, tab icon and wide layout.
st.set_page_config(
    layout="wide",
    page_icon=AGILE_ICON_URL,
    page_title="Agentic RAG for Agile Task Management",
)

# Raw HTML header (logo + title); rendered with unsafe_allow_html so the
# <h1>/<img> tags are interpreted instead of being escaped.
_header_html = f"""
<h1 style="display:flex; align-items:center; gap:12px; font-size:32px;">
<img src="{AGILE_ICON_URL}" width="170">
Agentic RAG for Agile Task Management
</h1>
"""
st.markdown(_header_html, unsafe_allow_html=True)

# One-line usage hint shown under the header.
_caption_text = (
    "Upload your preprocessed Agile project data (Taiga, Jira, GitHub, etc.), "
    "and chat with the AI Scrum Master assistant."
)
st.caption(_caption_text)
# -----------------------------
# Helpers: data loading
# -----------------------------
def _ensure_documents(obj: Any) -> List[Document]:
    """Normalize an uploaded payload into a list of LangChain ``Document`` objects.

    Accepted shapes:
      * ``None`` -> empty list.
      * a dict with a ``"documents"`` key -> that value is normalized recursively.
      * a list whose items are ``Document`` instances and/or dicts. For a dict,
        the text is the first truthy value among ``"page_content"``,
        ``"content"`` and ``"text"`` (empty string if none), and the metadata is
        the ``"metadata"`` value (``{}`` if missing or falsy).

    Args:
        obj: The decoded payload (from pickle / JSON / JSONL).

    Returns:
        The list of documents. When ``obj`` is already a non-empty list made
        only of ``Document`` instances, that same list object is returned.

    Raises:
        ValueError: If the payload, or any element of the list, has an
            unsupported shape.
    """
    unrecognized = (
        "Uploaded data format not recognized. Expected a list of LangChain Documents, "
        "or a list of dicts with page_content/metadata, or a dict with key 'documents'."
    )

    if obj is None:
        return []
    if isinstance(obj, dict) and "documents" in obj:
        return _ensure_documents(obj["documents"])
    if not isinstance(obj, list):
        raise ValueError(unrecognized)

    # Fast path: nothing to convert, hand the caller's list back untouched.
    if obj and all(isinstance(item, Document) for item in obj):
        return obj

    docs: List[Document] = []
    # Validate every element (not just the first) so a malformed entry in the
    # middle of the list is reported as a ValueError instead of surfacing as an
    # AttributeError from ``.get``.
    for index, item in enumerate(obj):
        if isinstance(item, Document):
            docs.append(item)
        elif isinstance(item, dict):
            page = item.get("page_content") or item.get("content") or item.get("text") or ""
            meta = item.get("metadata") or {}
            docs.append(Document(page_content=str(page), metadata=meta))
        else:
            raise ValueError(
                f"{unrecognized} (item {index} is of type {type(item).__name__})"
            )
    return docs
def load_uploaded_data(uploaded_file) -> List[Document]:
    """Parse an uploaded dataset file into LangChain documents.

    The format is selected from the (case-insensitive) file extension:
    ``.pkl`` / ``.pickle`` (pickle), ``.json`` (one JSON value) or ``.jsonl``
    (one JSON value per non-blank line). The decoded payload is then handed to
    ``_ensure_documents`` for shape normalization.

    Args:
        uploaded_file: Binary file-like object exposing ``.name`` and
            ``.getvalue()`` and readable by ``pickle.load`` / ``json.load``
            (presumably a Streamlit ``UploadedFile`` - confirm against caller).

    Returns:
        The documents contained in the file.

    Raises:
        ValueError: If the extension is unsupported, a ``.jsonl`` line is not
            valid JSON, or the payload shape is rejected by
            ``_ensure_documents``.
    """
    name = uploaded_file.name.lower()

    # Rewind defensively: if the same buffer was already consumed by an earlier
    # read, pickle.load / json.load would otherwise start at end-of-file.
    rewind = getattr(uploaded_file, "seek", None)
    if callable(rewind):
        rewind(0)

    if name.endswith((".pkl", ".pickle")):
        # SECURITY NOTE(review): unpickling executes arbitrary code embedded in
        # the file. Only acceptable if uploads come from trusted users; prefer
        # the JSON / JSONL formats for anything exposed publicly.
        obj = pickle.load(uploaded_file)
        return _ensure_documents(obj)

    if name.endswith(".json"):
        obj = json.load(uploaded_file)
        return _ensure_documents(obj)

    if name.endswith(".jsonl"):
        # "utf-8-sig" strips a leading byte-order mark, which json.loads would
        # otherwise reject on the first line.
        text = uploaded_file.getvalue().decode("utf-8-sig")
        records = []
        for line_no, raw_line in enumerate(text.splitlines(), start=1):
            line = raw_line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as exc:
                # Report which line is broken; a bare decoder error only gives
                # a column within that single line.
                raise ValueError(
                    f"Invalid JSON on line {line_no} of the .jsonl file: {exc}"
                ) from exc
        return _ensure_documents(records)

    raise ValueError("Unsupported file type. Please upload .pkl, .json, or .jsonl")
# -----------------------------
# Prompts
# -----------------------------
# System message prepended to the conversation in `generate_query_or_respond`;
# it tells the tool-calling model when to call the retrieval tool and when to
# answer from the conversation alone. Leading/trailing whitespace is stripped.
SYSTEM_PROMPT = """
You are an Agile assistant answering questions about Taiga project data.
You have access to a tool that retrieves user stories and tasks.
Use the tool when:
- the user asks about a specific user story
- the user asks about tasks, effort, assignees, sprint, story points
- the answer is not fully available in the conversation memory
- the question introduces new entities
Do NOT use the tool when:
- the question clearly refers to the immediately previous answer
- and all necessary information is already present in the conversation
Always prefer grounded answers based on retrieved context when uncertain.
""".strip()
# Relevance-grading prompt used by `grade_documents`.
# Placeholders: {context} (retrieved tool output) and {question}.
GRADE_PROMPT = (
    "You are a grader assessing relevance of a retrieved Taiga Agile project document to a user question.\n\n"
    "The document is a Taiga user story and may contain:\n"
    "- User Story description\n"
    "- Tasks\n"
    "- Story points\n"
    "- Estimated effort\n"
    "- Actual effort\n"
    "- Assignees\n"
    "- Sprint order\n\n"
    "Retrieved document:\n"
    "{context}\n\n"
    "User question:\n"
    "{question}\n\n"
    "If the document contains information useful for answering the question "
    "(especially about tasks, effort, story points, assignment, sprint, or task management), "
    "grade it as relevant.\n\n"
    "Respond ONLY with 'yes' or 'no'."
)
# Query-rewriting prompt used by `rewrite_question` after a retrieval was graded
# as not relevant. Placeholder: {question}.
REWRITE_PROMPT = (
    "You are rewriting a user question ONLY to improve retrieval over Taiga user stories and tasks.\n\n"
    "STRICT RULES:\n"
    "1) DO NOT change the meaning.\n"
    "2) DO NOT introduce new user story names, new entities, or new IDs.\n"
    "3) If the question mentions a specific user story, you MUST keep it exactly.\n"
    "4) If the question is already clear and retrieval-ready, return it unchanged.\n"
    "5) Keep it as ONE sentence. No lists, no extra commentary.\n\n"
    "Original question:\n"
    "{question}\n\n"
    "Rewritten question:"
)
# Grounded-answer prompt used by `generate_answer`.
# Placeholders: {question} and {context} (retrieved tool output).
GENERATE_PROMPT = (
    "You are an Agile assistant specialized in task planning and analysis.\n"
    "Use ONLY the retrieved context to answer.\n"
    "If the question involves tasks, provide a clear and structured plan when appropriate.\n"
    "NEVER infer missing fields.\n"
    "If the question asks for a field that is not explicitly present in the context, say 'Not provided in the data.'\n"
    "Do NOT treat Sprint Order as Priority unless the context explicitly says so.\n"
    "Provide a concise but structured answer.\n\n"
    "Question: {question}\n"
    "Context: {context}"
)
# Context-free answer prompt used by `generate_answer_without_context`, i.e.
# when the tool-calling model did not request a retrieval.
# Placeholder: {question}.
DIRECT_GENERATE_PROMPT = (
    "You are an Agile assistant specialized in task planning and analysis.\n"
    "Answer the user question directly.\n"
    "If the question requires project-specific data and no retrieved context is available, "
    "say that project data is needed.\n"
    "Do NOT invent project-specific details.\n"
    "Provide a concise but structured answer.\n\n"
    "Question: {question}"
)
# Structured-output schema for the relevance grader: it is passed to
# `with_structured_output(...)` in `grade_documents`, which then reads
# `binary_score` to pick the next graph node.
# NOTE(review): documented with `#` comments on purpose - a class docstring on
# a pydantic model may end up in the schema sent to the LLM; verify before
# converting this into a docstring.
class GradeDocuments(BaseModel):
    # "yes" routes to answer generation, "no" routes to question rewriting.
    binary_score: TypingLiteral["yes", "no"] = Field(
        description="Relevance score: 'yes' if relevant, or 'no' if not relevant"
    )
def get_last_human_text(messages) -> str:
    """Return the content of the most recent ``HumanMessage`` in ``messages``.

    When the sequence holds no human message, the content of the first message
    is returned instead; an empty sequence yields ``""``.
    """
    # Scan from newest to oldest and stop at the first human-authored message.
    latest_human = next(
        (msg for msg in reversed(messages) if isinstance(msg, HumanMessage)),
        None,
    )
    if latest_human is not None:
        return latest_human.content

    # No human message at all: fall back to the very first message, if any.
    if not messages:
        return ""
    return messages[0].content
# Hugging Face identifiers: the base checkpoint and the LoRA adapter applied on top of it.
BASE_MODEL = "Qwen/Qwen3-0.6B"
LORA_MODEL = "Razieh87/AgileTaskGen-Agent-Qwen3-0.6B"


@st.cache_resource(show_spinner=True)
def load_qwen_ft_model():
    """Load the tokenizer and the LoRA-adapted causal LM, cached by Streamlit.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has been switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)

    # Half precision when a CUDA device is available, full precision otherwise.
    if torch.cuda.is_available():
        weights_dtype = torch.float16
    else:
        weights_dtype = torch.float32

    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=weights_dtype,
        device_map="auto",
        trust_remote_code=True,
    )

    # Wrap the base weights with the fine-tuned adapter.
    model = PeftModel.from_pretrained(base_model, LORA_MODEL)
    model.eval()
    return tokenizer, model
class QwenFTAnswerModel:
    """Minimal chat-model-style wrapper around the fine-tuned Qwen model.

    Exposes a single ``invoke(messages)`` method returning an ``AIMessage`` so
    it can be called like the other chat models used in the graph.
    """

    def invoke(self, messages):
        """Generate an answer for the last message in ``messages``.

        Args:
            messages: Non-empty sequence whose last item is either a dict with
                a ``"content"`` key or an object with a ``.content`` attribute.
                Only that last item is used as the prompt.

        Returns:
            An ``AIMessage`` holding the generated continuation, stripped of
            surrounding whitespace.
        """
        tokenizer, model = load_qwen_ft_model()

        last_message = messages[-1]
        if isinstance(last_message, dict):
            user_prompt = last_message["content"]
        else:
            user_prompt = last_message.content

        inputs = tokenizer(user_prompt, return_tensors="pt").to(model.device)
        prompt_token_count = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )

        # For a causal LM the returned sequence starts with the prompt tokens.
        # Drop them by position rather than by string replacement: decoding does
        # not always reproduce the prompt text exactly, and str.replace would
        # also delete any part of the answer that happens to repeat the prompt.
        generated_tokens = outputs[0][prompt_token_count:]
        answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
        return AIMessage(content=answer)
@st.cache_resource(show_spinner=False)
def build_graph_from_documents(docs: List[Document], max_retrieval_attempts: int = 3):
    """Build the agentic-RAG LangGraph workflow over ``docs``.

    Pipeline: the documents are split, embedded into an in-memory vector store
    and exposed as a retrieval tool. The compiled graph then runs:

      1. ``generate_query_or_respond`` - the control model either calls the
         retrieval tool or answers directly.
      2. ``retrieve`` - executes the tool call.
      3. ``grade_documents`` (conditional edge) - routes to ``generate_answer``
         when the retrieved text is relevant, otherwise to ``rewrite_question``,
         which loops back to step 1.
      4. ``generate_answer`` / ``generate_answer_without_context`` - the
         fine-tuned Qwen model writes the final answer.

    Args:
        docs: Documents to index.
        max_retrieval_attempts: Maximum number of retrieval rounds per user
            turn. Once reached, the answer is generated from the latest
            retrieved context without grading it. Without this cap a run of
            "not relevant" grades keeps looping until LangGraph's recursion
            limit aborts the run.

    Returns:
        The compiled graph, checkpointed with an in-memory ``MemorySaver``.
    """
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=2500,
        chunk_overlap=100,
    )
    doc_splits = text_splitter.split_documents(docs)
    vectorstore = InMemoryVectorStore.from_documents(
        documents=doc_splits,
        embedding=OpenAIEmbeddings(),
    )
    retriever = vectorstore.as_retriever(search_kwargs={"k": 8})

    @tool
    def retrieve_user_stories(query: str) -> str:
        """Search and return relevant Taiga user stories and their tasks."""
        found = retriever.invoke(query)
        return "\n\n".join(d.page_content for d in found)

    retriever_tool = retrieve_user_stories

    # -----------------------------
    # Models
    # -----------------------------
    # OpenAI controls routing, tool calling, grading, and rewriting
    control_model = init_chat_model("gpt-4o-mini", temperature=0)
    answer_model = QwenFTAnswerModel()
    response_model = control_model
    grader_model = control_model

    # Build the derived runnables once instead of on every node invocation.
    tool_calling_model = response_model.bind_tools([retriever_tool])
    structured_grader = grader_model.with_structured_output(GradeDocuments)

    def _retrieval_rounds_this_turn(messages) -> int:
        """Count tool-calling AI messages since the last final AI answer."""
        rounds = 0
        for message in reversed(messages):
            if getattr(message, "type", None) != "ai":
                continue  # human / tool messages do not delimit a turn
            if getattr(message, "tool_calls", None):
                rounds += 1
            else:
                # An AI message without tool calls is a final answer, i.e. the
                # end of the previous turn.
                break
        return rounds

    # -----------------------------
    # Nodes
    # -----------------------------
    def generate_query_or_respond(state: MessagesState):
        """Let the control model call the retrieval tool or answer directly."""
        messages = state["messages"]
        messages_with_system = [SystemMessage(content=SYSTEM_PROMPT), *messages]
        msg = tool_calling_model.invoke(messages_with_system)
        return {"messages": [msg]}

    def grade_documents(state: MessagesState) -> Literal["generate_answer", "rewrite_question"]:
        """Route to answering or to rewriting based on retrieval relevance."""
        messages = state["messages"]

        # Retry budget exhausted: answer with what we have. Either grade would
        # lead to "generate_answer" here, so the grader call is skipped.
        if _retrieval_rounds_this_turn(messages) >= max_retrieval_attempts:
            return "generate_answer"

        question = get_last_human_text(messages)
        context = messages[-1].content  # latest tool output
        prompt = GRADE_PROMPT.format(question=question, context=context)
        response = structured_grader.invoke(
            [{"role": "user", "content": prompt}]
        )
        return "generate_answer" if response.binary_score == "yes" else "rewrite_question"

    def rewrite_question(state: MessagesState):
        """Append a retrieval-friendly rewrite of the latest human question."""
        question = get_last_human_text(state["messages"]).strip()
        prompt = REWRITE_PROMPT.format(question=question)
        rewritten = response_model.invoke(
            [{"role": "user", "content": prompt}]
        ).content.strip()
        if not rewritten:
            # Empty rewrite: retry with the question as it was.
            rewritten = question
        return {"messages": [HumanMessage(content=rewritten)]}

    def generate_answer(state: MessagesState):
        """Answer the latest human question from the latest retrieved context."""
        question = get_last_human_text(state["messages"])
        context = state["messages"][-1].content
        prompt = GENERATE_PROMPT.format(question=question, context=context)
        response = answer_model.invoke(
            [{"role": "user", "content": prompt}]
        )
        return {"messages": [response]}

    def generate_answer_without_context(state: MessagesState):
        """Answer the latest human question without any retrieved context."""
        question = get_last_human_text(state["messages"])
        prompt = DIRECT_GENERATE_PROMPT.format(question=question)
        response = answer_model.invoke(
            [{"role": "user", "content": prompt}]
        )
        return {"messages": [response]}

    # -----------------------------
    # Graph
    # -----------------------------
    workflow = StateGraph(MessagesState)
    workflow.add_node("generate_query_or_respond", generate_query_or_respond)
    workflow.add_node("retrieve", ToolNode([retriever_tool]))
    workflow.add_node("rewrite_question", rewrite_question)
    workflow.add_node("generate_answer", generate_answer)
    workflow.add_node("generate_answer_without_context", generate_answer_without_context)
    workflow.add_edge(START, "generate_query_or_respond")
    workflow.add_conditional_edges(
        "generate_query_or_respond",
        tools_condition,
        {
            "tools": "retrieve",
            # No tool call requested: have the answer model respond directly.
            END: "generate_answer_without_context",
        },
    )
    workflow.add_conditional_edges(
        "retrieve",
        grade_documents,
        {
            "generate_answer": "generate_answer",
            "rewrite_question": "rewrite_question",
        },
    )
    workflow.add_edge("generate_answer", END)
    workflow.add_edge("generate_answer_without_context", END)
    workflow.add_edge("rewrite_question", "generate_query_or_respond")
    memory = MemorySaver()
    graph = workflow.compile(checkpointer=memory)
    return graph
# -----------------------------
# Sidebar: API key + upload + settings
# -----------------------------
# Sidebar: resolves the OpenAI key, collects the dataset upload and the display
# settings, and manages the conversation thread id used for graph memory.
with st.sidebar:
    st.header("⚙️ Setup")
    import os
    import uuid

    # Best-effort lookup: st.secrets raises when no secrets file / key exists,
    # in which case we fall back to an already-exported environment variable.
    try:
        api_key = st.secrets["OPENAI_API_KEY"]
    except Exception:
        api_key = ""
    api_key = api_key or os.environ.get("OPENAI_API_KEY", "")
    if not api_key:
        st.error("OpenAI API key is not configured.")
        st.stop()
    # Read from the environment by the OpenAI clients created later.
    os.environ["OPENAI_API_KEY"] = api_key

    uploaded = st.file_uploader(
        "Upload your preprocessed RAG dataset",
        type=["pkl", "pickle", "json", "jsonl"],
        help="Upload the dataset that is already prepared for RAG documents.",
    )
    show_retrieved = st.toggle("Show retrieved context", value=True)
    tool_preview_chars = st.slider(
        "Retrieved context preview chars",
        200,
        3000,
        900,
        50,
    )

    # The compiled graph (and its in-memory checkpointer) is cached with
    # st.cache_resource, which is shared across sessions. A fixed default id
    # would therefore let separate sessions read each other's conversation, so
    # every session starts with its own random thread id.
    if not st.session_state.get("thread_id"):
        st.session_state["thread_id"] = f"Task-thread-{uuid.uuid4().hex[:8]}"
    thread_id = st.text_input(
        "Chat session",
        value=st.session_state["thread_id"],
        help="Keeps conversation memory. Change it to start a new session.",
    )
    st.session_state["thread_id"] = thread_id

    if st.button("🧹 Reset chat"):
        st.session_state.pop("chat_messages", None)
        st.session_state.pop("last_tool_context", None)
        # Switch to a brand-new thread: clearing only the UI state would leave
        # the old history in the graph's checkpointer under the same id.
        st.session_state["thread_id"] = f"Task-thread-{uuid.uuid4().hex[:8]}"
        st.rerun()
# -----------------------------
# Load docs + build graph
# -----------------------------
# Parse the uploaded dataset and build (or fetch the cached) RAG graph.
# `graph` stays None until both steps succeed; the chat section checks it.
graph = None
docs = None
error = None
# No API-key check here: the sidebar already calls st.stop() when the key is
# missing, so this point is only reached with a configured key.
if uploaded is None:
    st.info("Upload your preprocessed Taiga dataset to start.")
else:
    try:
        docs = load_uploaded_data(uploaded)
        if not docs:
            # An empty index cannot answer anything; report it instead of
            # building a vector store with no content.
            raise ValueError("The uploaded file contains no documents.")
        st.success(f"Loaded {len(docs)} documents.")
        graph = build_graph_from_documents(docs)
    except Exception as e:
        # UI boundary: surface any load/build failure to the user.
        error = str(e)
        st.error(f"Failed to load/build: {error}")
# -----------------------------
# Chat UI
# -----------------------------
# Chat section: replays the stored transcript, then runs one graph turn for a
# newly submitted prompt and records the result in the session state.
if "chat_messages" not in st.session_state:
    st.session_state["chat_messages"] = []
if "last_tool_context" not in st.session_state:
    st.session_state["last_tool_context"] = ""

# Re-render the transcript on every rerun.
for m in st.session_state["chat_messages"]:
    with st.chat_message(m["role"]):
        st.markdown(m["content"])

prompt = st.chat_input(
    "Ask about tasks, effort estimation, assignment, workload, or progress..."
)

if prompt and graph is None:
    # Without a built graph the prompt cannot be answered; say so instead of
    # silently discarding it.
    st.warning("Please upload a dataset before asking a question.")
elif prompt:
    st.session_state["chat_messages"].append(
        {"role": "user", "content": prompt}
    )
    with st.chat_message("user"):
        st.markdown(prompt)

    config = {"configurable": {"thread_id": st.session_state["thread_id"]}}
    tool_context_accum = ""
    assistant_text = ""

    with st.chat_message("assistant"):
        placeholder = st.empty()
        try:
            # stream_mode="values" yields the state after each step; only the
            # newest message of each snapshot is inspected.
            for step in graph.stream(
                {"messages": [HumanMessage(content=prompt)]},
                stream_mode="values",
                config=config,
            ):
                last_msg = step["messages"][-1]
                msg_type = getattr(last_msg, "type", None)
                if msg_type == "tool":
                    # Keep a (possibly truncated) preview of the latest retrieval.
                    content = last_msg.content or ""
                    if len(content) > tool_preview_chars:
                        content = content[:tool_preview_chars] + "\n… [truncated]"
                    tool_context_accum = content
                    continue
                if msg_type == "ai":
                    assistant_text = last_msg.content or ""
                    placeholder.markdown(assistant_text)
        except Exception as exc:
            # UI boundary: report the failure in the chat instead of crashing
            # the script run and leaving the user message without a reply.
            assistant_text = f"Sorry, I could not generate an answer: {exc}"
            placeholder.error(assistant_text)

        if show_retrieved and tool_context_accum:
            with st.expander("🔎 Retrieved context (tool output)", expanded=False):
                st.code(tool_context_accum)

    st.session_state["chat_messages"].append(
        {"role": "assistant", "content": assistant_text}
    )
    st.session_state["last_tool_context"] = tool_context_accum