import gradio as gr
from transformers import AutoModel, AutoTokenizer
from sklearn.cluster import KMeans
from kneed import KneeLocator
import torch
import re
import json

# ===== Quote Extractor =====
def extract_dialogues_with_context_from_text(raw_text, context_window=1):
    # Keep non-empty lines only; each line is a candidate dialogue/narration unit.
    lines = [line.strip() for line in raw_text.splitlines() if line.strip()]
    dialogue_data = []
    for i, line in enumerate(lines):
        # Match text enclosed in straight (") or curly (“ ”) double quotes.
        quotes = re.findall(r'[“"]([^”"]+)[”"]', line)
        for quote in quotes:
            # Surrounding lines (before and after) give the encoder context for the quote.
            context_lines = lines[max(0, i - context_window): i] + lines[i+1: i+1 + context_window]
            context = " ".join(context_lines)
            dialogue_data.append({
                "quote": quote,
                "context": context,
                "line_index": i
            })
    return dialogue_data
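
# A minimal sketch of what the extractor returns, assuming a two-line story where
# the quote appears on the second line (values are illustrative, not produced here):
#
#   extract_dialogues_with_context_from_text('Anna smiled.\n"Hello," she said.')
#   # -> [{"quote": "Hello,", "context": "Anna smiled.", "line_index": 1}]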

# ===== Encoder =====
def encode_quote(context: str, dialogue: str, tokenizer, model) -> torch.Tensor:
    # Concatenate context and quote so the encoder sees both in a single sequence.
    text = f"{context} [SEP] {dialogue}"
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512
    )
    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)
    # Use the [CLS] token embedding as the sentence-level representation.
    cls_embedding = outputs.last_hidden_state[:, 0, :]
    return cls_embedding.squeeze(0)

def load_encoder():
    tokenizer = AutoTokenizer.from_pretrained("aNameNobodyChose/quote-caster-encoder")
    model = AutoModel.from_pretrained("aNameNobodyChose/quote-caster-encoder")
    model.eval()
    return tokenizer, model

def embed_quotes(data, tokenizer, model):
    embeddings = []
    for ex in data:
        emb = encode_quote(ex["context"], ex["quote"], tokenizer, model)
        embeddings.append(emb)
    return torch.stack(embeddings)
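
# Shape sketch, assuming a BERT-style encoder with hidden size 768 (an assumption,
# not confirmed by the model card): embed_quotes(data, tokenizer, model) returns a
# tensor of shape [len(data), 768], one row per extracted quote.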

def auto_k_via_elbow(embeddings, max_k=10):
    X = embeddings.detach().numpy()
    # KMeans requires n_clusters <= n_samples, so cap k at the number of quotes.
    max_k = min(max_k, len(X))
    inertias = []
    for k in range(1, max_k + 1):
        kmeans = KMeans(n_clusters=k, random_state=42, n_init='auto')
        kmeans.fit(X)
        inertias.append(kmeans.inertia_)
    knee = KneeLocator(range(1, max_k + 1), inertias, curve="convex", direction="decreasing")
    # KneeLocator may find no clear elbow (knee is None); fall back to 2, capped at max_k.
    return min(knee.knee or 2, max_k)
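
# Quick sanity-check sketch (illustrative, not run by the app): on clearly separated
# synthetic blobs the elbow heuristic should recover the true cluster count, e.g.
#
#   fake = torch.cat([torch.randn(20, 8), torch.randn(20, 8) + 10.0])
#   auto_k_via_elbow(fake)  # expected to return 2 for such well-separated data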

# ===== Pipeline =====
def predict(story_text):
    try:
        data = extract_dialogues_with_context_from_text(story_text)
        if not data:
            return "❌ No quotes found in story. Make sure quotes are enclosed in double quotes (\")."

        # Embed each (context, quote) pair, pick k via the elbow heuristic, then
        # cluster the embeddings so that each cluster stands for one speaker.
        tokenizer, model = load_encoder()
        embeddings = embed_quotes(data, tokenizer, model)
        k = auto_k_via_elbow(embeddings)
        labels = KMeans(n_clusters=k, random_state=42, n_init='auto').fit_predict(embeddings.detach().numpy())

        for quote, cluster_id in zip(data, labels):
            quote["predicted_speaker"] = f"SPEAKER_{cluster_id}"

        return json.dumps(data, indent=2, ensure_ascii=False)
    except Exception as e:
        return f"❌ Error: {e}"

# ===== Gradio App =====
gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=30, label="Paste full story text (with quotes in double-quotes)"),
    outputs="textbox",
    title="🗣️ QuoteCaster - Speaker Attribution from Raw Text",
    description="Paste a full story containing dialogue in double quotes. The model will extract, embed, and cluster quotes by speaker."
).launch(server_name="0.0.0.0", server_port=7860)
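
# To try it locally (assuming the dependencies above are installed), run this script
# with Python and open http://localhost:7860 in a browser; server_name="0.0.0.0" also
# exposes the app on the local network, as is typical for containerized Spaces.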