byinab commited on
Commit
f396ce7
Β·
verified Β·
1 Parent(s): 771c66e

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +94 -106
src/streamlit_app.py CHANGED
@@ -3,59 +3,43 @@ import torch
3
  from transformers import pipeline
4
  import time
5
 
6
- # Page config
7
  st.set_page_config(
8
  page_title="πŸ“§ Email Reply Assistant",
9
  page_icon="πŸ“§",
10
- layout="wide",
11
- initial_sidebar_state="expanded"
12
  )
13
 
14
- # Custom CSS
15
  st.markdown("""
16
  <style>
17
- .main-header { font-size: 3rem; color: #1f77b4; text-align: center; margin-bottom: 2rem; }
18
- .pipeline-card { background: #f8f9fa; padding: 1.5rem; border-radius: 10px; border-left: 5px solid #1f77b4; }
19
- .status-success { color: #28a745; font-weight: bold; }
20
- .status-loading { color: #ffc107; font-weight: bold; }
 
21
  </style>
22
  """, unsafe_allow_html=True)
23
 
24
  @st.cache_resource
25
  def load_pipelines():
26
- """Load all 3 pipelines with GPU support"""
27
- with st.spinner("Loading AI models... This takes ~2 minutes"):
28
- classifier = pipeline(
29
- "text-classification",
30
- model="distilbert-base-uncased",
31
- device=0 if torch.cuda.is_available() else -1
32
- )
33
- generator = pipeline(
34
- "text-generation",
35
- model="Kunal7370944861/Email-Writer-AI",
36
- device=0 if torch.cuda.is_available() else -1
37
- )
38
- translator = pipeline(
39
- "translation",
40
- model="DDDSSS/translation_en-zh",
41
- device=0 if torch.cuda.is_available() else -1
42
- )
43
  return classifier, generator, translator
44
 
45
- # Load pipelines
46
  try:
47
  classifier, generator, translator = load_pipelines()
48
- st.success("βœ… All 3 AI pipelines loaded!")
49
  except Exception as e:
50
- st.error(f"❌ Pipeline loading failed: {str(e)}")
51
  st.stop()
52
 
53
- # Helper functions (same logic)
54
- def classify_email(text: str, classifier):
55
  result = classifier(text[:512])[0]
56
  return result["label"], float(result["score"])
57
 
58
- def build_prompt(email_text: str, category: str) -> str:
59
  return f"""You are a helpful customer service agent.
60
  Email category: {category}
61
 
@@ -65,101 +49,105 @@ Customer email:
65
  Write a polite, concise reply template.
66
  Reply:"""
67
 
68
- def generate_reply(prompt: str, generator):
69
- outputs = generator(
70
- prompt,
71
- max_length=300,
72
- num_return_sequences=1,
73
- do_sample=True,
74
- temperature=0.7,
75
- pad_token_id=generator.tokenizer.eos_token_id
76
- )
77
  full_text = outputs[0]["generated_text"]
78
  if "Reply:" in full_text:
79
  return full_text.split("Reply:", 1)[-1].strip()
80
  return full_text.replace(prompt, "").strip()
81
 
82
- def translate_reply(text: str, translator):
83
- if not text.strip():
84
- return ""
85
- outputs = translator(text)
86
- return outputs[0]["translation_text"].strip()
87
 
88
- # Main title
89
  st.markdown('<h1 class="main-header">πŸ€– Email Reply Assistant</h1>', unsafe_allow_html=True)
90
- st.markdown("**Classify β†’ Generate Reply β†’ Translate to Chinese** β€’ Powered by 3 Transformer models")
91
 
92
- # Sidebar info
93
  with st.sidebar:
94
- st.header("ℹ️ Pipeline Info")
95
- st.markdown("""
96
- **Pipeline 1**: `distilbert-base-uncased` β†’ Email classification
97
- **Pipeline 2**: `Kunal7370944861/Email-Writer-AI` β†’ Reply generation
98
- **Pipeline 3**: `DDDSSS/translation_en-zh` β†’ Chinese translation
99
-
100
- **Status**: βœ… All models loaded
101
- """)
102
  st.markdown("---")
103
- st.info("πŸ‘ˆ Paste email β†’ Click Process β†’ Get instant replies!")
104
 
105
- # Main content
106
  col1, col2 = st.columns([1, 2])
107
 
108
  with col1:
109
- st.header("πŸ“¨ Input Email")
110
-
111
- # Email input
112
  email_text = st.text_area(
113
- "Paste your email here...",
114
- placeholder="Subject: Problem with order\n\nHello,\n\nMy package arrived damaged...",
115
- height=200,
116
- help="Paste complete email (subject + body)"
117
  )
118
 
119
- col_btn1, col_btn2 = st.columns(2)
120
- with col_btn1:
121
- process_btn = st.button("πŸš€ PROCESS EMAIL", type="primary", use_container_width=True)
122
- with col_btn2:
123
- if st.button("🧹 CLEAR", use_container_width=True):
124
- st.rerun()
 
 
125
 
126
  with col2:
127
- if process_btn and email_text.strip():
128
- with st.spinner("πŸ”„ Processing through 3 AI pipelines..."):
129
- # Pipeline 1: Classification
130
- with st.container():
131
- st.markdown('<div class="pipeline-card"><h3>πŸ”’ Pipeline 1: Classification</h3>', unsafe_allow_html=True)
132
- label, score = classify_email(email_text, classifier)
133
- col_a, col_b = st.columns(2)
134
- col_a.metric("Recommended Tag", label)
135
- col_b.metric("Confidence", f"{score:.1%}")
136
- st.markdown('</div>', unsafe_allow_html=True)
137
-
138
- # Pipeline 2: Reply Generation
139
- with st.container():
140
- st.markdown('<div class="pipeline-card"><h3>βœ‰οΈ Pipeline 2: English Reply</h3>', unsafe_allow_html=True)
141
- prompt = build_prompt(email_text, label)
142
- reply_en = generate_reply(prompt, generator)
143
- st.text_area("English Reply Template", reply_en, height=150, disabled=True)
144
- st.markdown('</div>', unsafe_allow_html=True)
145
-
146
- # Pipeline 3: Translation
147
- with st.container():
148
- st.markdown('<div class="pipeline-card"><h3>πŸ‡¨πŸ‡³ Pipeline 3: Chinese Translation</h3>', unsafe_allow_html=True)
149
- reply_zh = translate_reply(reply_en, translator)
150
- st.text_area("δΈ­ζ–‡ε›žε€ζ¨‘ζΏ", reply_zh, height=150, disabled=True)
151
- st.markdown('</div>', unsafe_allow_html=True)
152
 
153
- # Copy buttons
154
- st.markdown("---")
155
- col_c, col_d = st.columns(2)
156
- with col_c:
157
- st.download_button("πŸ“₯ Download English", reply_en, "reply_en.txt")
158
- with col_d:
159
- st.download_button("πŸ“₯ Download Chinese", reply_zh, "reply_zh.txt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  else:
161
- st.info("πŸ‘† **Paste an email and click 'PROCESS EMAIL'** to see the magic!")
 
 
162
 
163
- # Footer
164
  st.markdown("---")
165
- st.markdown("*Built with ❀️ using Streamlit + Transformers β€’ Deployed on Hugging Face Spaces*")
 
3
  from transformers import pipeline
4
  import time
5
 
 
6
  st.set_page_config(
7
  page_title="πŸ“§ Email Reply Assistant",
8
  page_icon="πŸ“§",
9
+ layout="wide"
 
10
  )
11
 
 
12
  st.markdown("""
13
  <style>
14
+ .main-header {font-size: 3rem; color: #1f77b4; text-align: center;}
15
+ .pipeline-card {background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
16
+ padding: 1.5rem; border-radius: 15px; margin: 1rem 0;
17
+ border-left: 6px solid #1f77b4; box-shadow: 0 4px 6px rgba(0,0,0,0.1);}
18
+ .metric-card {background: white; padding: 1rem; border-radius: 10px; text-align: center;}
19
  </style>
20
  """, unsafe_allow_html=True)
21
 
22
  @st.cache_resource
23
  def load_pipelines():
24
+ with st.spinner('πŸ”„ Loading AI models (2-3 min)...'):
25
+ classifier = pipeline("text-classification", model="distilbert-base-uncased")
26
+ generator = pipeline("text-generation", model="Kunal7370944861/Email-Writer-AI")
27
+ translator = pipeline("translation", model="DDDSSS/translation_en-zh")
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  return classifier, generator, translator
29
 
30
+ # Load models safely
31
  try:
32
  classifier, generator, translator = load_pipelines()
33
+ st.success("βœ… All 3 pipelines ready!")
34
  except Exception as e:
35
+ st.error(f"Model loading error: {str(e)}")
36
  st.stop()
37
 
38
+ def classify_email(text, classifier):
 
39
  result = classifier(text[:512])[0]
40
  return result["label"], float(result["score"])
41
 
42
+ def build_prompt(email_text, category):
43
  return f"""You are a helpful customer service agent.
44
  Email category: {category}
45
 
 
49
  Write a polite, concise reply template.
50
  Reply:"""
51
 
52
+ def generate_reply(prompt, generator):
53
+ outputs = generator(prompt, max_length=300, num_return_sequences=1,
54
+ do_sample=True, temperature=0.7)
 
 
 
 
 
 
55
  full_text = outputs[0]["generated_text"]
56
  if "Reply:" in full_text:
57
  return full_text.split("Reply:", 1)[-1].strip()
58
  return full_text.replace(prompt, "").strip()
59
 
60
+ def translate_reply(text, translator):
61
+ if not text.strip(): return ""
62
+ return translator(text)[0]["translation_text"].strip()
 
 
63
 
64
+ # Header
65
  st.markdown('<h1 class="main-header">πŸ€– Email Reply Assistant</h1>', unsafe_allow_html=True)
66
+ st.markdown("**AI-powered: Classify β†’ Generate Reply β†’ Translate to Chinese**")
67
 
68
+ # Sidebar
69
  with st.sidebar:
70
+ st.header("πŸ”§ Pipeline Status")
71
+ st.success("βœ… **Pipeline 1**: `distilbert-base-uncased`")
72
+ st.success("βœ… **Pipeline 2**: `Kunal7370944861/Email-Writer-AI`")
73
+ st.success("βœ… **Pipeline 3**: `DDDSSS/translation_en-zh`")
 
 
 
 
74
  st.markdown("---")
75
+ st.info("πŸ‘ˆ **Paste email β†’ Process β†’ Copy replies!**")
76
 
77
+ # Main layout
78
  col1, col2 = st.columns([1, 2])
79
 
80
  with col1:
81
+ st.header("πŸ“¨ **Input Email**")
 
 
82
  email_text = st.text_area(
83
+ "Paste complete email here...",
84
+ placeholder="Subject: Order Issue\n\nHello,\nMy package arrived damaged...",
85
+ height=220
 
86
  )
87
 
88
+ if st.button("πŸš€ **PROCESS EMAIL**", type="primary", use_container_width=True):
89
+ if email_text.strip():
90
+ st.session_state.processed = True
91
+ st.session_state.email = email_text
92
+ else:
93
+ st.error("❌ Please paste an email first!")
94
+ if st.button("🧹 **CLEAR**", use_container_width=True):
95
+ st.rerun()
96
 
97
  with col2:
98
+ if 'processed' in st.session_state and st.session_state.processed:
99
+ email_text = st.session_state.email
100
+
101
+ # Pipeline 1: Classification
102
+ with st.container():
103
+ st.markdown('<div class="pipeline-card">', unsafe_allow_html=True)
104
+ st.markdown("### πŸ”’ **Pipeline 1: Email Classification**")
105
+ label, score = classify_email(email_text, classifier)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
+ col_a, col_b = st.columns(2)
108
+ with col_a:
109
+ st.markdown(f"""
110
+ <div class="metric-card">
111
+ <h3>🏷️ Tag</h3>
112
+ <h2>{label}</h2>
113
+ </div>
114
+ """, unsafe_allow_html=True)
115
+ with col_b:
116
+ st.markdown(f"""
117
+ <div class="metric-card">
118
+ <h3>πŸ“Š Confidence</h3>
119
+ <h2>{score:.1%}</h2>
120
+ </div>
121
+ """, unsafe_allow_html=True)
122
+ st.markdown('</div>', unsafe_allow_html=True)
123
+
124
+ # Pipeline 2: English Reply
125
+ with st.container():
126
+ st.markdown('<div class="pipeline-card">', unsafe_allow_html=True)
127
+ st.markdown("### βœ‰οΈ **Pipeline 2: English Reply**")
128
+ prompt = build_prompt(email_text, label)
129
+ reply_en = generate_reply(prompt, generator)
130
+ st.text_area("**Reply Template**", reply_en, height=140, disabled=True)
131
+ st.markdown('</div>', unsafe_allow_html=True)
132
+
133
+ # Pipeline 3: Chinese Translation
134
+ with st.container():
135
+ st.markdown('<div class="pipeline-card">', unsafe_allow_html=True)
136
+ st.markdown("### πŸ‡¨πŸ‡³ **Pipeline 3: Chinese Translation**")
137
+ reply_zh = translate_reply(reply_en, translator)
138
+ st.text_area("**δΈ­ζ–‡ε›žε€**", reply_zh, height=140, disabled=True)
139
+ st.markdown('</div>', unsafe_allow_html=True)
140
+
141
+ # Download buttons
142
+ col_c, col_d = st.columns(2)
143
+ with col_c:
144
+ st.download_button("πŸ“₯ Download English", reply_en, "email_reply_en.txt", use_container_width=True)
145
+ with col_d:
146
+ st.download_button("πŸ“₯ Download Chinese", reply_zh, "email_reply_zh.txt", use_container_width=True)
147
  else:
148
+ st.markdown('<div class="pipeline-card">', unsafe_allow_html=True)
149
+ st.info("🎯 **Paste your email above and click PROCESS**")
150
+ st.markdown('</div>', unsafe_allow_html=True)
151
 
 
152
  st.markdown("---")
153
+ st.markdown("*Powered by Streamlit + Transformers | Deployed on Hugging Face Spaces*")