JumaRubea committed on
Commit 043fc8b · verified · 1 Parent(s): 01498e5

Update src/streamlit_app.py

Files changed (1):
  1. src/streamlit_app.py +139 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,141 @@
- import altair as alt
- import numpy as np
- import pandas as pd
  import streamlit as st

- """
- # Welcome to Streamlit!
-
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
-
- In the meantime, below is an example of what you can do with just a few lines of code:
- """
-
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
- indices = np.linspace(0, 1, num_points)
- theta = 2 * np.pi * num_turns * indices
- radius = indices
-
- x = radius * np.cos(theta)
- y = radius * np.sin(theta)
-
- df = pd.DataFrame({
-     "x": x,
-     "y": y,
-     "idx": indices,
-     "rand": np.random.randn(num_points),
- })
-
- st.altair_chart(alt.Chart(df, height=700, width=700)
-     .mark_point(filled=True)
-     .encode(
-         x=alt.X("x", axis=None),
-         y=alt.Y("y", axis=None),
-         color=alt.Color("idx", legend=None, scale=alt.Scale()),
-         size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-     ))
+ import os
+ import threading
+ import torch
+ import requests
  import streamlit as st
+ from chats import init_db, get_all_chats, create_new_chat, save_message, get_messages, system_prompt

+ # Set HF cache directory (before transformers is imported below)
+ os.environ["HF_HOME"] = "/tmp/huggingface_cache"
+
+ # ------------------ FASTAPI BACKEND ------------------
+ from fastapi import FastAPI
+ from fastapi.responses import StreamingResponse, JSONResponse
+ from pydantic import BaseModel
+ import uvicorn
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ app = FastAPI()
+
+ class GenerationRequest(BaseModel):
+     system_message: str
+     user_prompt: str
+
+ # Load model/tokenizer once; st.cache_resource keeps Streamlit reruns
+ # from reloading the weights on every user interaction
+ @st.cache_resource(show_spinner=False)
+ def load_model():
+     tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+     model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model.to(device)
+     return tokenizer, model, device
+
+ tokenizer, model, device = load_model()
+
+ @app.post("/api/ai-generate")
+ async def generate_text_stream(request: GenerationRequest):
+     try:
+         messages = [
+             {"role": "system", "content": request.system_message},
+             {"role": "user", "content": request.user_prompt}
+         ]
+
+         inputs = tokenizer.apply_chat_template(
+             messages,
+             add_generation_prompt=True,
+             tokenize=True,
+             return_dict=True,
+             return_tensors="pt",
+         ).to(device)
+
+         def token_stream():
+             prompt_length = inputs["input_ids"].shape[-1]
+             # Generate the full completion, then stream it back token by token.
+             # Greedy decoding: temperature/top_p would be ignored while
+             # do_sample=False, so they are omitted here.
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=200,
+                 do_sample=False,
+                 eos_token_id=None,
+                 pad_token_id=tokenizer.eos_token_id,
+                 return_dict_in_generate=True,
+                 output_scores=False
+             )
+             sequence = outputs.sequences[0]
+             # Decode the completion cumulatively and yield only the new text;
+             # decoding single token ids in isolation drops the leading spaces
+             # that SentencePiece tokenizers encode inside each token
+             prev_text = ""
+             for i in range(prompt_length, sequence.shape[-1]):
+                 text = tokenizer.decode(sequence[prompt_length:i + 1], skip_special_tokens=True)
+                 new_text = text[len(prev_text):]
+                 if new_text:
+                     yield new_text
+                 prev_text = text
+             yield "\n"
+
+         return StreamingResponse(token_stream(), media_type="text/plain")
+
+     except Exception as e:
+         return JSONResponse(status_code=500, content={"error": str(e)})
+
+ def start_fastapi():
+     uvicorn.run(app, host="0.0.0.0", port=8000)
+
+ # Start FastAPI server in a background thread, once per process;
+ # st.cache_resource keeps Streamlit reruns from rebinding port 8000
+ @st.cache_resource(show_spinner=False)
+ def start_backend():
+     threading.Thread(target=start_fastapi, daemon=True).start()
+
+ start_backend()
+
+ # ------------------ STREAMLIT FRONTEND ------------------
+
+ init_db()
+
+ st.set_page_config(page_title="AI Assistant", page_icon="🤖")
+ st.title("🤖 Juma's Assistant")
+
+ st.sidebar.title("💬 Previous Chats")
+ all_chats = get_all_chats()
+
+ chat_titles = [f"{title} (ID: {chat_id})" for chat_id, title in all_chats]
+ selected_chat_index = st.sidebar.selectbox(
+     "Select Chat", range(len(all_chats)),
+     format_func=lambda i: chat_titles[i] if all_chats else "No chats available",
+ )
+
+ selected_chat_id = all_chats[selected_chat_index][0] if all_chats else None
+
+ if st.sidebar.button("🆕 Start New Chat"):
+     selected_chat_id = create_new_chat()
+     st.rerun()  # st.experimental_rerun() was removed in newer Streamlit releases
+
+ if selected_chat_id is None:
+     st.warning("Please start a new chat or select one from the sidebar.")
+     st.stop()
+
+ messages = get_messages(selected_chat_id)
+ for role, content in messages:
+     with st.chat_message(role):
+         st.markdown(content)
+
+ user_input = st.chat_input("Type your message...")
+ if user_input:
+     st.chat_message("user").markdown(user_input)
+     save_message(selected_chat_id, "user", user_input)
+
+     with st.spinner("Thinking..."):
+         try:
+             response = requests.post(
+                 "http://localhost:8000/api/ai-generate",
+                 json={
+                     "system_message": system_prompt(),
+                     "user_prompt": user_input
+                 },
+                 stream=True,
+                 timeout=120,
+             )
+
+             if response.status_code == 200:
+                 full_response = ""
+                 placeholder = st.empty()
+                 # Stream the reply chunk by chunk; decode_unicode uses an
+                 # incremental decoder, so multi-byte UTF-8 characters are
+                 # not split across 1-byte chunks
+                 response.encoding = "utf-8"
+                 for chunk in response.iter_content(chunk_size=1, decode_unicode=True):
+                     if chunk:
+                         full_response += chunk
+                         placeholder.markdown(full_response)
+                 st.chat_message("assistant").markdown(full_response)
+                 save_message(selected_chat_id, "assistant", full_response)
+             else:
+                 st.error(f"API call failed with status {response.status_code}.")
+         except Exception as e:
+             st.error(f"Error: {str(e)}")