File size: 6,994 Bytes
2e13d88
 
 
 
 
 
 
 
 
3931bcf
2e13d88
 
 
 
 
 
 
 
 
 
2028d24
2e13d88
 
2028d24
 
2e13d88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2028d24
2e13d88
2028d24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e13d88
 
 
 
 
 
2028d24
2e13d88
2028d24
 
2e13d88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2028d24
2e13d88
 
 
 
 
 
2028d24
2e13d88
 
 
 
 
 
 
 
 
 
 
 
2028d24
2e13d88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2028d24
 
2e13d88
 
2028d24
2e13d88
 
2028d24
 
2e13d88
 
 
 
 
 
 
 
 
 
2028d24
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import gradio as gr
from transformers import pipeline
from functools import lru_cache

# Fallback candidate topics for zero-shot classification, used whenever the
# user leaves the comma-separated label box empty (see analyze()).
DEFAULT_LABELS = [
    "finance", "sports", "tech", "politics", "health", "entertainment",
    "science", "business", "travel", "education"
]


@lru_cache(maxsize=1)
def get_pipes():
    """Build and cache the three Hugging Face pipelines used by the app.

    Returns:
        tuple: ``(summarizer, zshot, sentiment)`` pipelines. The
        ``lru_cache(maxsize=1)`` decorator ensures the models are
        downloaded and loaded only once per process.
    """
    # Abstractive summarizer (distilled BART checkpoint).
    summarizer = pipeline(
        "summarization",
        model="sshleifer/distilbart-cnn-12-6"
    )
    # Zero-shot topic classifier (NLI-based distilled BART).
    zshot = pipeline(
        "zero-shot-classification",
        model="valhalla/distilbart-mnli-12-1"
    )
    # 3-class sentiment: NEGATIVE / NEUTRAL / POSITIVE
    sentiment = pipeline(
        "sentiment-analysis",
        model="cardiffnlp/twitter-roberta-base-sentiment-latest",
        tokenizer="cardiffnlp/twitter-roberta-base-sentiment-latest"
    )
    return summarizer, zshot, sentiment


def chunk_text(text: str, max_chars: int = 1600):
    """Naive chunker to keep inputs within summarizer limits.

    Splits on sentences by '. ' and groups them into chunks of at most
    ``max_chars`` characters. A single sentence longer than ``max_chars``
    is hard-split so that no returned chunk ever exceeds the limit
    (the original version could emit oversized chunks in that case,
    defeating the length cap).

    Args:
        text: Arbitrary input text; newlines are treated as spaces.
        max_chars: Upper bound on the length of each returned chunk.

    Returns:
        list[str]: Non-empty chunks; empty list for empty/whitespace text
        with no fallback characters.
    """
    sentences = [s.strip() for s in text.replace("\n", " ").split(". ") if s.strip()]
    chunks, buf = [], ""
    for s in sentences:
        # Re-attach the period lost by split("."); keep a trailing space
        # as a separator between grouped sentences.
        add = (s + (". " if not s.endswith(".") else " "))
        if len(buf) + len(add) <= max_chars:
            buf += add
        else:
            if buf:
                chunks.append(buf.strip())
            buf = add
            # Bug fix: a lone sentence longer than max_chars must be
            # hard-split, otherwise the chunk exceeds the limit.
            while len(buf) > max_chars:
                chunks.append(buf[:max_chars].strip())
                buf = buf[max_chars:]
    if buf:
        chunks.append(buf.strip())
    # Fallback if text had no periods (e.g. one giant unpunctuated blob).
    if not chunks:
        for i in range(0, len(text), max_chars):
            chunks.append(text[i:i+max_chars])
    return chunks


def summarize_long(text: str, target_words: int = 120):
    """Summarize arbitrarily long text by chunking, then fusing.

    Each ~1600-char chunk is summarized independently; when several
    partial summaries together overshoot the word target, a second
    summarization pass fuses them into one.

    Args:
        text: The full input text.
        target_words: Rough desired length of the final summary in words.

    Returns:
        str: The stripped summary text.
    """
    summarizer = get_pipes()[0]
    # Translate the rough word budget into model token limits.
    max_len = min(256, max(64, int(target_words * 1.6)))
    min_len = max(20, int(max_len * 0.4))

    partials = []
    for segment in chunk_text(text, max_chars=1600):
        try:
            result = summarizer(segment, max_length=max_len, min_length=min_len, do_sample=False)
        except Exception:
            # Best-effort retry on a smaller window if the model
            # complains about input length.
            result = summarizer(segment[:1200], max_length=max_len, min_length=min_len, do_sample=False)
        partials.append(result[0]["summary_text"])

    combined = " ".join(partials)
    # Fuse with a second pass only when multiple partials overshoot.
    if len(partials) > 1 and len(combined.split()) > target_words:
        result = summarizer(combined, max_length=max_len, min_length=min_len, do_sample=False)
        return result[0]["summary_text"].strip()
    return combined.strip()


def classify_topics(text: str, labels: list[str]):
    """Rank candidate topic labels against *text* via zero-shot NLI.

    Args:
        text: Text to classify.
        labels: Candidate label strings.

    Returns:
        tuple: ``(ranked, top3)`` where ``ranked`` is all
        ``(label, score)`` pairs sorted by score descending and
        ``top3`` is the first three of them.
    """
    zshot = get_pipes()[1]
    result = zshot(text, candidate_labels=labels, multi_label=True)
    ranked = sorted(
        zip(result["labels"], result["scores"]),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return ranked, ranked[:3]


def analyze_sentiment(text: str):
    """3-class sentiment with chunk-aware averaging for long inputs.

    Splits the text into small chunks, averages the per-class
    probabilities over the first few chunks, and returns the
    winning class.

    Args:
        text: Input text (may be long; only the first 8 small chunks
            are scored, for speed).

    Returns:
        tuple: ``(label, score)`` — the argmax class name and its mean
        probability across the scored chunks.
    """
    _, _, sentiment = get_pipes()
    # Smaller chunk for sentiment; keep first few for speed
    s_chunks = chunk_text(text, max_chars=300) or [text[:300]]
    s_chunks = s_chunks[:8]

    agg = {"NEGATIVE": 0.0, "NEUTRAL": 0.0, "POSITIVE": 0.0}
    for ch in s_chunks:
        # top_k=None returns scores for every class; it replaces the
        # deprecated return_all_scores=True pipeline argument.
        out = sentiment(ch, top_k=None)
        # Single-string calls yield a flat list of {label, score} dicts;
        # older/batched call styles nest it one level deeper — accept both.
        scores = out[0] if out and isinstance(out[0], list) else out
        for s in scores:
            # Accumulate via .get() so an unexpected label name is
            # tallied instead of raising KeyError mid-analysis.
            label = s["label"].upper()
            agg[label] = agg.get(label, 0.0) + float(s["score"])
    n = float(len(s_chunks))
    for k in agg:
        agg[k] /= n

    label = max(agg, key=agg.get)
    score = agg[label]
    return label, score


def analyze(text, labels_csv, summary_words):
    """Run summarization, topic classification, and sentiment for the UI.

    Args:
        text: Raw user text (may be None/empty).
        labels_csv: Comma-separated candidate labels; blank falls back
            to DEFAULT_LABELS.
        summary_words: Target summary length in words (slider value).

    Returns:
        tuple: ``(summary, table_rows, top_topics_str, sentiment_label,
        sentiment_score)`` matching the Gradio output components.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        # Nothing to analyze: blank out every output component.
        return "", [], "", "", 0.0

    # CSV label string -> list, with the defaults as fallback.
    raw_labels = (labels_csv or "").strip()
    labels = [part.strip() for part in raw_labels.split(",") if part.strip()] or DEFAULT_LABELS

    summary = summarize_long(cleaned, target_words=int(summary_words))
    pairs, top3 = classify_topics(cleaned, labels)
    sent_label, sent_score = analyze_sentiment(cleaned)

    # Friendly "label (score)" listing of the leading topics.
    top_str = ", ".join(f"{lab} ({score:.2f})" for lab, score in top3) if top3 else ""

    # Dataframe component expects list-of-rows.
    table_rows = [[lab, round(score, 4)] for lab, score in pairs]

    return summary, table_rows, top_str, sent_label, sent_score


# UI definition. NOTE: components render in declaration order, so the
# statement order inside this block is load-bearing — code kept verbatim.
with gr.Blocks(title="TriScope — Text Insight Stack", css="""
:root{--radius:16px}
.header {font-size: 28px; font-weight: 800;}
.subtle {opacity:.8}
.card {border:1px solid #e5e7eb; border-radius: var(--radius); padding:16px}
""") as demo:
    gr.Markdown("""
    <div class="header">🧠 TriScope — Text Insight Stack</div>
    <div class="subtle">Summarize • Topic Classify • Sentiment — powered by three open models on Hugging Face</div>
    """)

    with gr.Row():
        # Left column: all inputs plus the trigger button.
        with gr.Column(scale=5):
            txt = gr.Textbox(
                label="Paste text",
                placeholder="Paste any article, JD, email, or paragraph...",
                lines=12,
                elem_classes=["card"],
            )
            labels = gr.Textbox(
                label="Candidate topic labels (comma-separated)",
                value=", ".join(DEFAULT_LABELS),
                elem_classes=["card"],
            )
            words = gr.Slider(
                minimum=40, maximum=200, value=120, step=10,
                label="Target summary length (words)",
                elem_classes=["card"],
            )
            run = gr.Button("Analyze", variant="primary")

        # Right column: one tab per analysis result.
        with gr.Column(scale=5):
            with gr.Tab("Summary"):
                out_summary = gr.Markdown()
            with gr.Tab("Topics"):
                out_table = gr.Dataframe(headers=["label", "score"], datatype=["str", "number"], interactive=False)
                out_top = gr.Markdown()
            with gr.Tab("Sentiment"):
                # Show 3 classes
                # NOTE(review): analyze() returns a single label string, so
                # num_top_classes=3 may never display three classes — confirm
                # against gr.Label's accepted input types.
                out_sent_label = gr.Label(num_top_classes=3)
                out_sent_score = gr.Number(label="Confidence score")

    gr.Examples(
        label="Try an example",
        examples=[[
            "Open-source models are transforming AI by enabling broad access to powerful capabilities. However, organizations must balance innovation with governance, ensuring that safety and compliance keep pace with deployment. This article explores how companies can adopt a pragmatic approach to evaluation, monitoring, and human oversight while still benefiting from the speed of open development."
        ]],
        inputs=[txt]
    )

    # Wire the button: 3 inputs -> 5 outputs, in the order analyze() returns.
    run.click(
        analyze,
        inputs=[txt, labels, words],
        outputs=[out_summary, out_table, out_top, out_sent_label, out_sent_score]
    )

if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (the Hugging Face Spaces default);
    # debug=True surfaces server logs and tracebacks in the console.
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)