# Extracted from a web file view (Spaces status: Sleeping; file size: 7,520 bytes; revision bda91f5).
import os
import json
import random
from pathlib import Path
from PIL import Image
from fastapi import FastAPI, Request, File, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import shutil
import uuid
import torch
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoImageProcessor
from peft import PeftModel
# --- Application setup ---
app = FastAPI()

# Static assets are served from app/static under the /static URL prefix.
_static_files = StaticFiles(directory="app/static")
app.mount("/static", _static_files, name="static")

# HTML templates are looked up in app/templates.
templates = Jinja2Templates(directory="app/templates")
# --- Model Loading ---
# Everything below runs once at import time and populates the module-level
# objects (`base_model`, `tok`, `fe`) used by the request handlers.
MODEL_NAME = 'tuman/vit-rugpt2-image-captioning'
ADAPTER_DIR = 'letitbE/image2wiki-adapter'
DATA_ROOT = Path('.')

print("Loading base model...")
base_model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)

print("Loading tokenizer and feature extractor...")
try:
    # Prefer the tokenizer stored alongside the adapter.
    tok = AutoTokenizer.from_pretrained(ADAPTER_DIR)
except Exception as e:
    # Best-effort fallback to the base tokenizer. `Exception` (not a bare
    # `except:`) so Ctrl-C / SystemExit still propagate, and the reason for
    # the fallback is reported instead of being silently swallowed.
    print(f"WARNING: Could not load tokenizer from {ADAPTER_DIR}: {e}. Using base tokenizer.")
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
fe = AutoImageProcessor.from_pretrained(MODEL_NAME)

print("Resizing embeddings...")
# Make the decoder's embedding table match the size of the tokenizer in use.
base_model.decoder.resize_token_embeddings(len(tok))

print("Loading LoRA adapter...")
try:
    base_model.decoder = PeftModel.from_pretrained(base_model.decoder, ADAPTER_DIR)
except Exception as e:
    # Deliberate best-effort: the app keeps running on the un-adapted model.
    print(f"WARNING: Could not load adapter {ADAPTER_DIR}: {e}. Using base model.")

# Switch to evaluation mode for inference.
base_model.eval()
# --- Data Loading ---
def load_valid_articles(metadata_path, data_root):
    """Return the JSONL records whose ``image_path`` points at an existing file.

    Args:
        metadata_path: Path to a JSON-Lines file with one article object per line.
        data_root: ``Path`` against which relative ``image_path`` values are
            resolved.

    Returns:
        A list of the decoded article dicts, in file order.

    Raises:
        OSError: If ``metadata_path`` cannot be opened or read.

    Blank lines, malformed JSON lines and non-object records are skipped so
    that one bad line does not discard the rest of the file.
    """
    articles = []
    with open(metadata_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                art = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Skipping malformed metadata line {line_no}: {e}")
                continue
            if not isinstance(art, dict):
                continue
            raw_path = art.get('image_path')
            if not raw_path:
                # Path('') is '.', which always exists; a record without an
                # image path must not be counted as having an image.
                continue
            img_path = Path(raw_path)
            if not img_path.is_absolute():
                img_path = data_root / img_path
            if img_path.is_file():
                articles.append(art)
    return articles


print("Loading dataset...")
valid_arts = []
try:
    valid_arts = load_valid_articles('data/metadata.jsonl', DATA_ROOT)
    print(f"Found {len(valid_arts)} valid articles with images.")
except Exception as e:
    # Best-effort: the app still starts (with an empty list) without a dataset.
    print(f"Error loading dataset: {e}")
def build_target(article):
    """Serialize an article dict into the tagged training-target string.

    Each non-empty field is emitted as ``<tag>value`` and the pieces are
    joined with newlines, in this order: ``<title>``, ``<lead>``, then for
    every section its ``<section>`` title followed by its ``<paragraph>``
    text. The original comment states this mirrors the logic in
    finetune.ipynb.

    Args:
        article: Mapping with optional ``title``, ``lead`` and ``sections``
            keys; ``sections`` is an iterable of mappings with optional
            ``title`` and ``text`` keys. A missing or ``None`` ``sections``
            value is treated as "no sections".

    Returns:
        The newline-joined tagged string ('' when nothing is present).
    """
    parts = []
    if article.get('title'):
        parts.append(f"<title>{article['title']}")
    if article.get('lead'):
        parts.append(f"<lead>{article['lead']}")
    # `or []` also covers an explicit null, which `.get(key, [])` lets through.
    for sec in article.get('sections') or []:
        if sec.get('title'):
            parts.append(f"<section>{sec['title']}")
        if sec.get('text'):
            parts.append(f"<paragraph>{sec['text']}")
    return "\n".join(parts)
def _split_section_heading(content):
    """Split already-stripped section content into ``(heading, rest)``.

    The heading ends at the first period (kept in the heading) or the first
    newline (dropped), whichever comes first; with neither present the whole
    content is the heading and ``rest`` is ''.
    """
    cut_points = []
    period_idx = content.find('.')
    if period_idx != -1:
        cut_points.append(period_idx + 1)
    newline_idx = content.find('\n')
    if newline_idx != -1:
        cut_points.append(newline_idx)
    cut = min(cut_points, default=len(content))
    return content[:cut].strip(), content[cut:].strip()


def _render_toc(toc_items):
    """Render ``(sec_id, heading)`` pairs as numbered table-of-contents items."""
    return "".join(
        '\n'
        f'<li id="toc-{sec_id}" class="vector-toc-list-item vector-toc-level-1">\n'
        f'<a class="vector-toc-link" href="#{sec_id}">\n'
        '<div class="vector-toc-text">\n'
        f'<span class="vector-toc-numb">{i}</span>\n'
        f'<span>{heading}</span>\n'
        '</div>\n'
        '</a>\n'
        '</li>\n'
        for i, (sec_id, heading) in enumerate(toc_items, 1)
    )


def parse_generated_text(text):
    """Parse raw tagged model output into pieces for the article template.

    The input uses the ``<title>``, ``<lead>``, ``<section>`` and
    ``<paragraph>`` markers. Lead and paragraph contents become ``<p>...</p>``
    elements; each section becomes an ``<h2>`` heading block (plus a
    paragraph for any text following the heading) and a table-of-contents
    entry.

    Args:
        text: Raw generated string.

    Returns:
        A ``(title, html, toc_html)`` tuple of strings. ``title`` falls back
        to a default when none can be extracted; ``html`` and ``toc_html``
        may be empty.

    NOTE(review): the generated text is inserted into the markup without
    HTML escaping, exactly as before; confirm this is acceptable for how the
    template renders it.
    """
    import re  # local import: `re` is not imported at file level

    default_title = "Сгенерированная статья"
    title = default_title

    # --- Title extraction ---
    if text.startswith('<title>'):
        remainder = text[len('<title>'):]
        # The title runs up to the first structural tag (or to the end).
        tag_positions = [
            pos
            for pos in (remainder.find(tag) for tag in ('<lead>', '<section>', '<paragraph>'))
            if pos != -1
        ]
        cut = min(tag_positions, default=len(remainder))
        title = remainder[:cut].strip() or default_title
        text = remainder[cut:]
    elif '<lead>' in text:
        before, _, after = text.partition('<lead>')
        title = before.strip() or default_title
        text = '<lead>' + after

    # --- Body ---
    # Tokenize on the structural tags *before* producing any HTML. With one
    # capture group, re.split yields [preamble, tag, content, tag, content...],
    # so each section's content stops at the next tag and a heading can never
    # absorb the paragraph that follows it.
    pieces = re.split(r'(<lead>|<section>|<paragraph>)', text)
    html_parts = []
    toc_items = []

    preamble = pieces[0].strip()
    if preamble:
        html_parts.append(f'<p>{preamble}</p>')

    for tag, content in zip(pieces[1::2], pieces[2::2]):
        content = content.strip()
        if tag == '<section>':
            heading, content = _split_section_heading(content)
            if heading:
                sec_id = heading.replace(' ', '_').replace('"', '').replace("'", "")
                toc_items.append((sec_id, heading))
                html_parts.append(
                    f'<div class="mw-heading mw-heading2"><h2 id="{sec_id}">{heading}</h2></div>'
                )
        # Empty paragraphs are dropped rather than emitted as <p></p>.
        if content:
            html_parts.append(f'<p>{content}</p>')

    html = "".join(html_parts)
    toc_html = _render_toc(toc_items)
    return title, html, toc_html
def generate_article_raw(image, model, tokenizer, feature_extractor):
    """Generate the raw tagged article text for one image.

    Args:
        image: Image passed to ``feature_extractor`` as ``images=``.
        model: Object exposing ``generate(pixel_values, **kwargs)``.
        tokenizer: Object exposing ``decode`` plus ``bos_token``,
            ``eos_token``, ``bos_token_id`` and ``eos_token_id``.
        feature_extractor: Callable returning an object with a
            ``pixel_values`` attribute.

    Returns:
        The decoded text of the first generated sequence with the bos/eos
        token strings removed and surrounding whitespace stripped.
    """
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    with torch.no_grad():
        output_ids = model.generate(
            pixel_values,
            max_new_tokens=512,
            # Beam search is disabled (num_beams=1) for local CPU testing so
            # that generation runs several times faster.
            num_beams=1,
            no_repeat_ngram_size=3,
            decoder_start_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Special tokens are kept while decoding so the structural tags survive;
    # only bos/eos are removed afterwards.
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=False)
    for special in (tokenizer.bos_token, tokenizer.eos_token):
        # A tokenizer may define no bos/eos token (None); str.replace(None, '')
        # would raise TypeError.
        if special:
            generated_text = generated_text.replace(special, '')
    return generated_text.strip()
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Render the landing page from the ``search.html`` template."""
    context = {}  # the template is rendered with an empty context
    response = templates.TemplateResponse(request, "search.html", context)
    return response
# Extensions allowed for files stored in the publicly served upload folder.
_ALLOWED_UPLOAD_EXTENSIONS = frozenset(
    {"jpg", "jpeg", "png", "gif", "bmp", "webp", "tif", "tiff"}
)


def _safe_upload_extension(filename):
    """Return a whitelisted, lower-cased extension for ``filename``, else 'jpg'."""
    if filename and '.' in filename:
        ext = filename.rsplit('.', 1)[-1].lower()
        if ext in _ALLOWED_UPLOAD_EXTENSIONS:
            return ext
    return 'jpg'


@app.post("/generate", response_class=HTMLResponse)
async def generate_article(request: Request, image: UploadFile = File(...)):
    """Store an uploaded image, generate an article for it and render the page.

    The upload is saved under ``app/static/uploads`` with a random UUID name
    so the rendered page can reference it via ``/static/uploads/...``. Any
    failure while opening the image or generating text is reported inside the
    rendered page instead of propagating.

    Args:
        request: Incoming request, forwarded to the template.
        image: Uploaded image file.

    Returns:
        The rendered ``index.html`` template response.

    NOTE(review): generation runs synchronously inside this ``async`` handler,
    so other requests wait while it runs.
    """
    from html import escape  # local import: `html` is not imported at file level

    # Save the uploaded image. The extension comes from the client, so it is
    # restricted to known image types: it can otherwise contain path
    # separators, or be something like "html" inside a served directory.
    # A missing filename (None) falls back to "jpg".
    upload_dir = Path("app/static/uploads")
    upload_dir.mkdir(parents=True, exist_ok=True)

    file_ext = _safe_upload_extension(image.filename)
    filename = f"{uuid.uuid4()}.{file_ext}"
    file_path = upload_dir / filename

    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(image.file, buffer)

    try:
        # Open the image for the model; the context manager releases the
        # underlying file handle once the RGB copy has been made.
        with Image.open(file_path) as source_image:
            pil_image = source_image.convert('RGB')

        # Generate text
        print(f"Generating article for {filename}...")
        generated_raw = generate_article_raw(pil_image, base_model, tok, fe)

        # The parser treats the text before the first structural tag as the
        # title only when the string starts with <title>.
        if not generated_raw.startswith('<title>'):
            generated_raw = "<title>" + generated_raw

        title, content_html, toc_html = parse_generated_text(generated_raw)
    except Exception as e:
        # Request-level boundary: show the failure in the page, not a 500.
        print(f"Error during generation: {e}")
        title = "Ошибка генерации"
        # The message is embedded in HTML, so it must be escaped.
        content_html = f"<p>Произошла ошибка при создании статьи: {escape(str(e))}</p>"
        toc_html = ""

    img_url = f"/static/uploads/{filename}"

    return templates.TemplateResponse(request, "index.html", {
        "title": title,
        "content": content_html,
        "toc": toc_html,
        "image_url": img_url,
        "target_text": ""
    })
# Serve files from the local "data" directory under /data (used for images).
# StaticFiles needs a directory, so check is_dir() rather than mere existence:
# a regular file named "data" would otherwise break startup.
if Path("data").is_dir():
    app.mount("/data", StaticFiles(directory="data"), name="data")