Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoModel, AutoTokenizer | |
| from sklearn.cluster import KMeans | |
| from kneed import KneeLocator | |
| import torch | |
| import re | |
| import json | |
# ===== Quote Extractor =====
def extract_dialogues_with_context_from_text(raw_text, context_window=1):
    """Extract every double-quoted utterance from *raw_text* with nearby context.

    Quotes may use straight (") or typographic (“ ”) double quotes.
    The original regex contained mojibake (`β`) where the curly quotes
    “ / ” belonged, so typographic quotes were never matched.

    Args:
        raw_text: full story text; lines are split on newlines.
        context_window: number of non-empty lines before and after the
            quote's line to join into the context string.

    Returns:
        List of dicts with keys "quote", "context", and "line_index"
        (index into the list of non-empty, stripped lines).
    """
    lines = [line.strip() for line in raw_text.splitlines() if line.strip()]
    dialogue_data = []
    for i, line in enumerate(lines):
        # Opening quote: “ or " ; closing quote: ” or " .
        quotes = re.findall(r'[“"]([^“”"]+)[”"]', line)
        for quote in quotes:
            context_lines = lines[max(0, i - context_window): i] + lines[i + 1: i + 1 + context_window]
            context = " ".join(context_lines)
            dialogue_data.append({
                "quote": quote,
                "context": context,
                "line_index": i
            })
    return dialogue_data
# ===== Encoder =====
def encode_quote(context: str, dialogue: str, tokenizer, model) -> torch.Tensor:
    """Embed a quote together with its surrounding context.

    The context and the dialogue are joined with a literal " [SEP] " marker,
    tokenized, and run through the encoder; the first-token ([CLS]) hidden
    state is returned as the quote's embedding.

    Args:
        context: surrounding narrative text for the quote.
        dialogue: the quoted utterance itself.
        tokenizer: Hugging Face tokenizer matching *model*.
        model: Hugging Face encoder model (expected to be in eval mode).

    Returns:
        1-D tensor of shape (hidden_dim,) — the [CLS] embedding.
    """
    text = f"{context} [SEP] {dialogue}"
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512
    )
    # Inference only: disable autograd so no activation graph is retained.
    with torch.no_grad():
        outputs = model(**inputs)
    cls_embedding = outputs.last_hidden_state[:, 0, :]
    return cls_embedding.squeeze(0)
def load_encoder():
    """Load the QuoteCaster tokenizer and encoder from the Hub, in eval mode."""
    repo_id = "aNameNobodyChose/quote-caster-encoder"
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    encoder = AutoModel.from_pretrained(repo_id)
    encoder.eval()
    return tokenizer, encoder
def embed_quotes(data, tokenizer, model):
    """Encode each extracted quote (with its context) and stack into one tensor.

    Returns a tensor of shape (n_quotes, hidden_dim).
    """
    return torch.stack([
        encode_quote(item["context"], item["quote"], tokenizer, model)
        for item in data
    ])
def auto_k_via_elbow(embeddings, max_k=10):
    """Choose a cluster count via the elbow of the KMeans inertia curve.

    Args:
        embeddings: 2-D torch tensor of quote embeddings (n_quotes, dim).
        max_k: upper bound on the number of clusters to try.

    Returns:
        The detected elbow k, or 2 when no clear elbow is found.
    """
    X = embeddings.detach().numpy()
    # KMeans requires n_samples >= n_clusters; without this cap the loop
    # crashes whenever the story has fewer quotes than max_k.
    max_k = min(max_k, len(X))
    ks = range(1, max_k + 1)
    inertias = []
    for k in ks:
        kmeans = KMeans(n_clusters=k, random_state=42, n_init='auto')
        kmeans.fit(X)
        inertias.append(kmeans.inertia_)
    knee = KneeLocator(ks, inertias, curve="convex", direction="decreasing")
    return knee.knee or 2
# ===== Pipeline =====
def predict(story_text):
    """Run the full QuoteCaster pipeline on raw story text.

    Extracts double-quoted dialogue, embeds each quote with its context,
    picks a cluster count via the elbow method, clusters the embeddings,
    and labels each quote with a pseudo-speaker id.

    Args:
        story_text: full story text pasted by the user.

    Returns:
        Pretty-printed JSON of the annotated quotes, or a human-readable
        error message (this is surfaced directly in the Gradio textbox).
    """
    try:
        data = extract_dialogues_with_context_from_text(story_text)
        if not data:
            return "❌ No quotes found in story. Make sure quotes are enclosed in double quotes (\")."
        tokenizer, model = load_encoder()
        embeddings = embed_quotes(data, tokenizer, model)
        k = auto_k_via_elbow(embeddings)
        # Seed the final clustering the same way as the elbow search so the
        # speaker assignment is reproducible across runs.
        labels = KMeans(n_clusters=k, random_state=42, n_init='auto').fit_predict(
            embeddings.detach().numpy()
        )
        for entry, cluster_id in zip(data, labels):
            entry["predicted_speaker"] = f"SPEAKER_{cluster_id}"
        return json.dumps(data, indent=2, ensure_ascii=False)
    except Exception as e:
        # Top-level UI boundary: report the failure as text instead of crashing.
        return f"❌ Error: {e}"
# ===== Gradio App =====
# Launch the Space UI: one textbox in, one textbox out, bound to predict().
# (The title emoji had been mangled by an encoding round-trip; restored to 🗣️.)
gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=30, label="Paste full story text (with quotes in double-quotes)"),
    outputs="textbox",
    title="🗣️ QuoteCaster - Speaker Attribution from Raw Text",
    description="Paste a full story containing dialogue in double quotes. The model will extract, embed, and cluster quotes by speaker."
).launch(server_name="0.0.0.0", server_port=7860)