import os import json import random from pathlib import Path from PIL import Image from fastapi import FastAPI, Request, File, UploadFile from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates import shutil import uuid import torch from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoImageProcessor from peft import PeftModel app = FastAPI() app.mount("/static", StaticFiles(directory="app/static"), name="static") templates = Jinja2Templates(directory="app/templates") # --- Model Loading --- MODEL_NAME = 'tuman/vit-rugpt2-image-captioning' ADAPTER_DIR = 'letitbE/image2wiki-adapter' DATA_ROOT = Path('.') print("Loading base model...") base_model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME) print("Loading tokenizer and feature extractor...") try: tok = AutoTokenizer.from_pretrained(ADAPTER_DIR) except: tok = AutoTokenizer.from_pretrained(MODEL_NAME) fe = AutoImageProcessor.from_pretrained(MODEL_NAME) print("Resizing embeddings...") base_model.decoder.resize_token_embeddings(len(tok)) print("Loading LoRA adapter...") try: base_model.decoder = PeftModel.from_pretrained(base_model.decoder, ADAPTER_DIR) except Exception as e: print(f"WARNING: Could not load adapter {ADAPTER_DIR}: {e}. Using base model.") base_model.eval() # --- Data Loading --- print("Loading dataset...") valid_arts = [] try: with open('data/metadata.jsonl', 'r', encoding='utf-8') as f: for line in f: art = json.loads(line) img_path = Path(art.get('image_path', '')) if not img_path.is_absolute(): img_path = DATA_ROOT / img_path if img_path.exists(): valid_arts.append(art) print(f"Found {len(valid_arts)} valid articles with images.") except Exception as e: print(f"Error loading dataset: {e}") def build_target(article): # Same logic as in finetune.ipynb parts = [] if article.get('title'): parts.append(f"{article['title']}") if article.get('lead'): parts.append(f"<lead>{article['lead']}") for sec in article.get('sections', []): if sec.get('title'): parts.append(f"<section>{sec['title']}") if sec.get('text'): parts.append(f"<paragraph>{sec['text']}") return "\n".join(parts) def parse_generated_text(text): """Parses the raw generated text into HTML for the Wikipedia template.""" title = "Сгенерированная статья" # Extract title if present if text.startswith('<title>'): parts = text.split('<title>', 1)[1] # Find next tag next_tag_idx = len(parts) for tag in ['<lead>', '<section>', '<paragraph>']: idx = parts.find(tag) if idx != -1 and idx < next_tag_idx: next_tag_idx = idx title = parts[:next_tag_idx].strip() text = parts[next_tag_idx:] elif '<lead>' in text: title = text.split('<lead>')[0].strip() text = '<lead>' + text.split('<lead>', 1)[1] # Replace tags with HTML html = text html = html.replace('<lead>', '<p>') html = html.replace('<paragraph>', '</p><p>') toc_items = [] def section_replacer(match): content = match.group(1) # Split by first period or newline split_idx = len(content) period_idx = content.find('.') if period_idx != -1: split_idx = period_idx + 1 heading = content[:split_idx].strip() rest = content[split_idx:].strip() sec_id = heading.replace(' ', '_').replace('"', '').replace("'", "") toc_items.append((sec_id, heading)) res = f'</p><div class="mw-heading mw-heading2"><h2 id="{sec_id}">{heading}</h2></div>' if rest: res += f'<p>{rest}' return res import re html = re.sub(r'<section>(.*?)(?=<section>|<paragraph>|<lead>|$)', section_replacer, html, flags=re.DOTALL) # Clean up empty paragraphs html = html.replace('<p></p>', '') if not html.endswith('</p>'): html += '</p>' # Generate TOC HTML toc_html = "" for i, (sec_id, heading) in enumerate(toc_items, 1): toc_html += f''' <li id="toc-{sec_id}" class="vector-toc-list-item vector-toc-level-1"> <a class="vector-toc-link" href="#{sec_id}"> <div class="vector-toc-text"> <span class="vector-toc-numb">{i}</span> <span>{heading}</span> </div> </a> </li> ''' return title, html, toc_html def generate_article_raw(image, model, tokenizer, feature_extractor): pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values with torch.no_grad(): output_ids = model.generate( pixel_values, max_new_tokens=512, # Для локального тестирования на CPU отключаем beam search (num_beams=1) # чтобы генерация работала в разы быстрее num_beams=1, no_repeat_ngram_size=3, decoder_start_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id, ) generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=False) # Remove bos/eos generated_text = generated_text.replace(tokenizer.bos_token, '').replace(tokenizer.eos_token, '').strip() return generated_text @app.get("/", response_class=HTMLResponse) async def read_root(request: Request): return templates.TemplateResponse(request, "search.html", {}) @app.post("/generate", response_class=HTMLResponse) async def generate_article(request: Request, image: UploadFile = File(...)): # Save the uploaded image temporarily upload_dir = Path("app/static/uploads") upload_dir.mkdir(parents=True, exist_ok=True) file_ext = image.filename.split('.')[-1] if '.' in image.filename else 'jpg' filename = f"{uuid.uuid4()}.{file_ext}" file_path = upload_dir / filename with open(file_path, "wb") as buffer: shutil.copyfileobj(image.file, buffer) # Open image for the model try: pil_image = Image.open(file_path).convert('RGB') # Generate text print(f"Generating article for {filename}...") generated_raw = generate_article_raw(pil_image, base_model, tok, fe) if not generated_raw.startswith('<title>'): generated_raw = "<title>" + generated_raw title, content_html, toc_html = parse_generated_text(generated_raw) except Exception as e: print(f"Error during generation: {e}") title = "Ошибка генерации" content_html = f"<p>Произошла ошибка при создании статьи: {str(e)}</p>" toc_html = "" img_url = f"/static/uploads/{filename}" return templates.TemplateResponse(request, "index.html", { "title": title, "content": content_html, "toc": toc_html, "image_url": img_url, "target_text": "" }) # Mount the root directory to serve images import os if os.path.exists("data"): app.mount("/data", StaticFiles(directory="data"), name="data")