import os
import gradio as gr
import torch
from transformers import pipeline

# ─── 1) SET UP PIPELINES ──────────────────────────────────────────────────────
# Sentiment analysis (local; pin the task's default model explicitly to avoid
# the "No model was supplied" warning)
sentiment_pipe = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)

# Text generation chat (local GPT-2)
device = 0 if torch.cuda.is_available() else -1
chat_pipe = pipeline(
    "text-generation",
    model="distilgpt2",
    tokenizer="distilgpt2",
    device=device,
    max_new_tokens=100,
    do_sample=True,
    temperature=0.7
)

def respond(message, chat_history):
    """
    - If the user starts with "Sentiment:", run sentiment analysis on the rest.
    - Otherwise fall back to GPT-2 chat continuation.
    """
    if message.lower().startswith("sentiment:"):
        text = message[len("sentiment:"):].strip()
        result = sentiment_pipe(text)[0]
        label = result["label"]
        score = result["score"]
        reply = f"🔍 Sentiment: **{label}** (score: {score:.3f})"
    else:
        # GPT-2 continuation; the pipeline returns [{'generated_text': "<prompt + continuation>"}]
        out = chat_pipe(message)
        # generated_text echoes the prompt, so drop it and keep only the continuation
        reply = out[0]["generated_text"][len(message):].strip()
    # gr.Chatbot expects a list of (user, bot) pairs, one per exchange
    chat_history.append((message, reply))
    return chat_history
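The prefix routing above can be exercised without loading any models; a minimal standalone sketch of the same parsing rule (the helper name `route` is hypothetical):

```python
# Minimal sketch of the "Sentiment:" prefix routing used in respond()
# (hypothetical helper; no pipelines required).
def route(message):
    # Case-insensitive prefix check, mirroring respond()
    if message.lower().startswith("sentiment:"):
        return ("sentiment", message[len("sentiment:"):].strip())
    return ("chat", message)

print(route("Sentiment: I love this!"))  # ('sentiment', 'I love this!')
print(route("Tell me a story"))          # ('chat', 'Tell me a story')
```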

# ─── 2) BUILD THE UI ───────────────────────────────────────────────────────────
with gr.Blocks() as demo:
    gr.Markdown("## 😊 Sentiment-&-Chat Bot\n"
                "_Type `Sentiment: <your text>` to analyze sentiment, or just chat!_")
    chat = gr.Chatbot()
    msg  = gr.Textbox(placeholder="Type here…", show_label=False)

    # chain a second step to clear the textbox after each submission
    msg.submit(respond, [msg, chat], [chat]).then(lambda: "", None, msg)

# queue() enables request queuing, which Hugging Face Spaces expects for
# long-running inference calls
demo.queue()

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=port)
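For a Hugging Face Space, a `requirements.txt` along these lines would typically accompany the script (unpinned here as a sketch; exact versions are an assumption):

```text
gradio
torch
transformers
```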