File size: 5,583 Bytes
6835659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Optional
import json
import re
import urllib.request

from src.exceptions import GenerationError

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class TextGenResult:
    """Immutable result of a text-generation call; ``text`` is the final output string."""

    text: str


def _sanitize_text(s: str) -> str:
    """Remove common failure patterns (echoing rules, bullets, repetitions)."""
    s = s.strip()

    # Remove markdown/bullets
    s = re.sub(r"^\s*[-*•]\s+", "", s, flags=re.MULTILINE)

    # Remove obvious meta/instruction echoes
    bad_patterns = [
        r"(?i)\blength\s*:\s*\d+\s*[-–]\s*\d+\s*sentences\b.*",
        r"(?i)\brules\s*:\b.*",
        r"(?i)\bno bullet points\b.*",
        r"(?i)\bno repetition\b.*",
        r"(?i)\bno meta commentary\b.*",
        r"(?i)\bdescribe only\b.*",
    ]
    for pat in bad_patterns:
        s = re.sub(pat, "", s).strip()

    # Collapse whitespace
    s = re.sub(r"\n{3,}", "\n\n", s)
    s = re.sub(r"[ \t]{2,}", " ", s)

    # If the model repeated the same line many times, de-dup consecutive duplicates
    lines = [ln.strip() for ln in s.splitlines() if ln.strip()]
    deduped = []
    for ln in lines:
        if not deduped or deduped[-1] != ln:
            deduped.append(ln)
    s = "\n".join(deduped).strip()

    return s


def _ollama_generate(
    prompt: str,
    model: str = "qwen2:7b",
    temperature: float = 0.7,
    top_p: float = 0.9,
    num_predict: int = 180,
    host: str = "http://localhost:11434",
) -> str:
    """Send *prompt* to a local Ollama server (POST /api/generate) and return
    the generated text.

    Raises:
        GenerationError: if the HTTP request or response parsing fails.
    """
    endpoint = f"{host.rstrip('/')}/api/generate"
    body = json.dumps(
        {
            "model": model,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": temperature,
                "top_p": top_p,
                "num_predict": num_predict,
            },
        }
    ).encode("utf-8")

    request = urllib.request.Request(
        endpoint,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )

    try:
        with urllib.request.urlopen(request, timeout=600) as resp:
            parsed = json.loads(resp.read().decode("utf-8"))
            text = parsed.get("response", "").strip()
            logger.debug("Ollama generated %d chars", len(text))
            return text
    except Exception as e:
        logger.error("Ollama call failed on %s: %s", host, e)
        raise GenerationError(
            f"Ollama call failed. Is Ollama running on {host}? Error: {e}",
            modality="text", backend=f"ollama/{model}",
        ) from e


class TextGenerator:
    """Text generator with two interchangeable backends.

    Option A (recommended): an Ollama instruction-following model reached
    over HTTP. When ``use_ollama`` is False, a Hugging Face
    ``text-generation`` pipeline is used instead.
    """

    def __init__(
        self,
        use_ollama: bool = True,
        ollama_model: str = "qwen2:7b",
        ollama_host: str = "http://localhost:11434",
        max_new_tokens: int = 160,
        hf_model_name: str = "gpt2",
    ):
        self.use_ollama = use_ollama
        self.ollama_model = ollama_model
        self.ollama_host = ollama_host
        self.max_new_tokens = max_new_tokens
        self.hf_model_name = hf_model_name

        # The HF pipeline is built only when the Ollama path is disabled;
        # the lazy import keeps transformers optional for Ollama users.
        self._hf_pipe = None
        if not self.use_ollama:
            from transformers import pipeline

            self._hf_pipe = pipeline("text-generation", model=self.hf_model_name)

    def _wrap_prompt(self, prompt: str) -> str:
        """Embed the caller's scene plan inside strict generation rules."""
        header = """You are a concise descriptive writer.

Write a literal description of the same scene. Follow these rules:
- Write 3 to 5 natural sentences.
- No bullet points, no numbered lists.
- No repetition.
- No meta commentary (do not mention rules, prompts, or constraints).
- Focus on concrete visual details AND the likely audio ambience.

SCENE PLAN:
"""
        return f"{header}{prompt}\n\nNow write the description:\n"

    def generate(self, prompt: str, deterministic: bool = True) -> TextGenResult:
        """Generate a description for *prompt*, sanitized of instruction echoes.

        When *deterministic* is True, sampling is disabled (temperature 0).
        """
        full_prompt = self._wrap_prompt(prompt)

        if not self.use_ollama:
            # HF fallback path
            candidates = self._hf_pipe(
                full_prompt,
                max_new_tokens=self.max_new_tokens,
                do_sample=not deterministic,
                temperature=0.0 if deterministic else 0.9,
                top_p=1.0 if deterministic else 0.95,
                num_return_sequences=1,
            )
            return TextGenResult(text=_sanitize_text(candidates[0]["generated_text"]))

        raw = _ollama_generate(
            prompt=full_prompt,
            model=self.ollama_model,
            host=self.ollama_host,
            temperature=0.0 if deterministic else 0.7,
            top_p=1.0 if deterministic else 0.9,
            num_predict=max(self.max_new_tokens, 120),
        )
        cleaned = _sanitize_text(raw)
        # Last safety net: if sanitizing emptied the text, return the raw output.
        return TextGenResult(text=cleaned or raw)


def generate_text(
    prompt: str,
    use_ollama: bool = True,
    deterministic: bool = True,
    ollama_model: str = "qwen2:7b",
    ollama_host: str = "http://localhost:11434",
    max_new_tokens: int = 160,
    hf_model_name: str = "gpt2",
) -> str:
    """Convenience wrapper: build a TextGenerator, run one generation, return the text."""
    result = TextGenerator(
        use_ollama=use_ollama,
        ollama_model=ollama_model,
        ollama_host=ollama_host,
        max_new_tokens=max_new_tokens,
        hf_model_name=hf_model_name,
    ).generate(prompt, deterministic=deterministic)
    return result.text