File size: 7,520 Bytes
bda91f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import html
import json
import os
import random
import re
import shutil
import uuid
from pathlib import Path

import torch
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from peft import PeftModel
from PIL import Image
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoImageProcessor

# FastAPI application: page templates are loaded from app/templates and the
# app/static directory is served under the /static URL prefix.
app = FastAPI()
templates = Jinja2Templates(directory="app/templates")
app.mount(
    "/static",
    StaticFiles(directory="app/static"),
    name="static",
)

# --- Model Loading ---
MODEL_NAME = 'tuman/vit-rugpt2-image-captioning'
ADAPTER_DIR = 'letitbE/image2wiki-adapter'
DATA_ROOT = Path('.')  # base directory for relative image paths in data/metadata.jsonl

print("Loading base model...")
base_model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
print("Loading tokenizer and feature extractor...")
try:
    # Prefer the tokenizer saved alongside the adapter (presumably it carries
    # tokens added during fine-tuning — confirm); fall back to the base one.
    tok = AutoTokenizer.from_pretrained(ADAPTER_DIR)
except Exception as e:
    # `except Exception` rather than a bare `except:` so KeyboardInterrupt and
    # SystemExit still propagate; the fallback is reported instead of silent.
    print(f"WARNING: Could not load tokenizer from {ADAPTER_DIR}: {e}. Using {MODEL_NAME} tokenizer.")
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
fe = AutoImageProcessor.from_pretrained(MODEL_NAME)

print("Resizing embeddings...")
# The decoder embedding table must have a row for every id the tokenizer can emit.
base_model.decoder.resize_token_embeddings(len(tok))

print("Loading LoRA adapter...")
try:
    # Only the decoder is wrapped with the adapter weights.
    base_model.decoder = PeftModel.from_pretrained(base_model.decoder, ADAPTER_DIR)
except Exception as e:
    # Best-effort: keep serving with the un-adapted decoder.
    print(f"WARNING: Could not load adapter {ADAPTER_DIR}: {e}. Using base model.")
base_model.eval()

# --- Data Loading ---
# Collect the articles from data/metadata.jsonl whose `image_path` points to an
# existing file (relative paths are resolved against DATA_ROOT).
print("Loading dataset...")
valid_arts = []
try:
    with open('data/metadata.jsonl', 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, 1):
            if not line.strip():
                continue  # blank lines (e.g. a trailing newline) are not records
            try:
                art = json.loads(line)
            except json.JSONDecodeError as e:
                # Skip just this record instead of abandoning the rest of the file.
                print(f"Skipping malformed line {line_no} in data/metadata.jsonl: {e}")
                continue
            if not isinstance(art, dict):
                continue
            raw_path = art.get('image_path')
            # An empty path would become Path('.'), which always "exists".
            if not raw_path:
                continue
            img_path = Path(raw_path)
            if not img_path.is_absolute():
                img_path = DATA_ROOT / img_path
            if img_path.is_file():
                valid_arts.append(art)
    print(f"Found {len(valid_arts)} valid articles with images.")
except Exception as e:
    # Best-effort at startup: the app still runs without the dataset.
    print(f"Error loading dataset: {e}")

def build_target(article):
    """Serialize *article* into the tag-delimited text format used for training.

    Mirrors the target construction in finetune.ipynb: an optional ``<title>``
    line, an optional ``<lead>`` line, then for every entry of ``sections`` an
    optional ``<section>`` heading line followed by an optional ``<paragraph>``
    body line. Falsy fields are omitted; the lines are joined with newlines.
    """
    lines = [
        f"{marker}{article[key]}"
        for marker, key in (("<title>", "title"), ("<lead>", "lead"))
        if article.get(key)
    ]
    for section in article.get('sections', []):
        for marker, key in (("<section>", "title"), ("<paragraph>", "text")):
            if section.get(key):
                lines.append(f"{marker}{section[key]}")
    return "\n".join(lines)

# Structural markers the model emits inside the article body. The capturing
# group makes re.split keep each marker in its output list.
_TAG_PATTERN = re.compile(r'(<lead>|<section>|<paragraph>)')


def _split_heading(content):
    """Split raw section text into ``(heading, rest)``.

    The heading ends at the first newline (excluded) or the first period
    (included), whichever comes first; both parts are whitespace-stripped.
    """
    content = content.strip()
    cut = len(content)
    newline_idx = content.find('\n')
    if newline_idx != -1:
        cut = newline_idx
    period_idx = content.find('.')
    if period_idx != -1 and period_idx < cut:
        cut = period_idx + 1
    return content[:cut].strip(), content[cut:].strip()


def _section_id(heading, used_ids):
    """Return an anchor id for *heading* not yet in *used_ids*, and record it."""
    base = heading.replace(' ', '_').replace('"', '').replace("'", "") or 'section'
    candidate = base
    suffix = 1
    while candidate in used_ids:
        suffix += 1
        candidate = f"{base}_{suffix}"
    used_ids.add(candidate)
    return candidate


def parse_generated_text(text):
    """Parse raw generated text into HTML pieces for the Wikipedia template.

    The text may start with ``<title>``; the body uses ``<lead>``, ``<section>``
    and ``<paragraph>`` markers. Lead and paragraph text become ``<p>``
    elements; each ``<section>`` becomes an ``<h2>`` heading (split off at the
    first newline or period) plus a paragraph for any remaining text.

    All model-generated text placed into the returned HTML is escaped with
    ``html.escape``. The title is returned unescaped (the template is expected
    to escape it — confirm).

    Returns:
        ``(title, content_html, toc_html)``. The title falls back to
        "Сгенерированная статья" when none (or an empty one) is found.
    """
    title = "Сгенерированная статья"

    if text.startswith('<title>'):
        remainder = text[len('<title>'):]
        # The title runs up to the first body marker (or the end of the text).
        match = _TAG_PATTERN.search(remainder)
        cut = match.start() if match else len(remainder)
        title = remainder[:cut].strip() or title
        text = remainder[cut:]
    elif '<lead>' in text:
        before, _, after = text.partition('<lead>')
        title = before.strip() or title
        text = '<lead>' + after

    blocks = []
    toc_items = []
    used_ids = set()

    def add_paragraph(fragment):
        # Whitespace-only fragments would only produce empty <p> elements.
        fragment = fragment.strip()
        if fragment:
            blocks.append(f'<p>{html.escape(fragment)}</p>')

    # segments = [text_before_first_tag, tag, content, tag, content, ...]
    segments = _TAG_PATTERN.split(text)
    add_paragraph(segments[0])
    for tag, content in zip(segments[1::2], segments[2::2]):
        if tag == '<section>':
            heading, rest = _split_heading(content)
            if heading:
                sec_id = _section_id(heading, used_ids)
                toc_items.append((sec_id, heading))
                blocks.append(
                    '<div class="mw-heading mw-heading2">'
                    f'<h2 id="{html.escape(sec_id)}">{html.escape(heading)}</h2></div>'
                )
            add_paragraph(rest)
        else:
            add_paragraph(content)

    toc_parts = []
    for i, (sec_id, heading) in enumerate(toc_items, 1):
        safe_id = html.escape(sec_id)
        toc_parts.append(f'''
        <li id="toc-{safe_id}" class="vector-toc-list-item vector-toc-level-1">
            <a class="vector-toc-link" href="#{safe_id}">
                <div class="vector-toc-text">
                    <span class="vector-toc-numb">{i}</span>
                    <span>{html.escape(heading)}</span>
                </div>
            </a>
        </li>
        ''')

    return title, ''.join(blocks), ''.join(toc_parts)

def generate_article_raw(image, model, tokenizer, feature_extractor,
                         max_new_tokens=512, num_beams=1):
    """Run the image-to-text model on *image* and return the decoded text.

    Args:
        image: Image handed to *feature_extractor* (the caller passes an RGB
            PIL image).
        model: Encoder-decoder model exposing ``generate``.
        tokenizer: Tokenizer supplying bos/eos ids and used for decoding.
        feature_extractor: Image processor returning ``pixel_values``.
        max_new_tokens: Upper bound on generated tokens (default 512).
        num_beams: Beam width. The default 1 disables beam search, which makes
            generation several times faster for local CPU testing.

    Returns:
        The decoded text with the bos/eos token strings removed and outer
        whitespace stripped. Other special tokens are kept
        (``skip_special_tokens=False``) — presumably so tag tokens such as
        ``<title>`` survive decoding; confirm against the tokenizer.
    """
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values

    with torch.no_grad():
        output_ids = model.generate(
            pixel_values,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            no_repeat_ngram_size=3,
            decoder_start_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=False)
    # Remove bos/eos markers. Either may be None for some tokenizers, and
    # str.replace(None, ...) would raise TypeError, so skip missing ones.
    for marker in (tokenizer.bos_token, tokenizer.eos_token):
        if marker:
            generated_text = generated_text.replace(marker, '')
    return generated_text.strip()

@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Render the ``search.html`` landing page with an empty template context."""
    context = {}
    return templates.TemplateResponse(request, "search.html", context)

@app.post("/generate", response_class=HTMLResponse)
async def generate_article(request: Request, image: UploadFile = File(...)):
    """Store an uploaded image, generate an article for it, and render ``index.html``.

    The image is written to ``app/static/uploads`` under a random UUID name so
    the result page can show it from ``/static/uploads/<name>``. Any exception
    while decoding the image or generating text is reported inside the page
    instead of failing the request.
    """
    upload_dir = Path("app/static/uploads")
    upload_dir.mkdir(parents=True, exist_ok=True)

    # The client-supplied filename may be missing or contain path separators.
    # Keep only a short ASCII-alphanumeric extension, else fall back to "jpg",
    # so the stored path always stays inside upload_dir.
    file_ext = Path(image.filename or "").suffix[1:].lower()
    if not (file_ext.isascii() and file_ext.isalnum() and len(file_ext) <= 10):
        file_ext = 'jpg'
    filename = f"{uuid.uuid4()}.{file_ext}"
    file_path = upload_dir / filename

    # NOTE(review): the upload is kept because the result page links to it, and
    # nothing here deletes it, so the uploads directory grows without bound.
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(image.file, buffer)

    try:
        # The context manager releases the file handle once the RGB copy exists.
        with Image.open(file_path) as opened:
            pil_image = opened.convert('RGB')

        print(f"Generating article for {filename}...")
        # Generation is synchronous and slow; running it in the threadpool keeps
        # the event loop free to serve other requests in the meantime.
        generated_raw = await run_in_threadpool(generate_article_raw, pil_image, base_model, tok, fe)

        # parse_generated_text extracts a title only from a leading <title> tag.
        if not generated_raw.startswith('<title>'):
            generated_raw = "<title>" + generated_raw

        title, content_html, toc_html = parse_generated_text(generated_raw)

    except Exception as e:
        print(f"Error during generation: {e}")
        title = "Ошибка генерации"
        # content_html carries markup, so the template presumably renders it
        # unescaped; escape the exception text before embedding it.
        content_html = f"<p>Произошла ошибка при создании статьи: {html.escape(str(e))}</p>"
        toc_html = ""

    img_url = f"/static/uploads/{filename}"

    return templates.TemplateResponse(request, "index.html", {
        "title": title,
        "content": content_html,
        "toc": toc_html,
        "image_url": img_url,
        "target_text": ""
    })

# Serve the local ``data`` directory at /data when it exists (presumably for the
# dataset images referenced by data/metadata.jsonl). Note that every file in it,
# including metadata.jsonl, becomes publicly downloadable.
import os

# isdir rather than exists: StaticFiles rejects a non-directory path at startup.
if os.path.isdir("data"):
    app.mount("/data", StaticFiles(directory="data"), name="data")