# ClickbaitFighter — Hugging Face Space (runs on ZeroGPU hardware)
import datetime
import os
from collections import OrderedDict
from typing import Any

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    LogitsProcessorList,
    TextStreamer,
)

from cache_system import CacheHandler
from download_url import download_text_and_title
from prompts import (
    summarize_clickbait_large_prompt,
    summarize_clickbait_short_prompt,
    summarize_prompt,
)
from utils import StopAfterTokenIsGenerated
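# Token for writing flagged samples to the private dataset. The "or True"
# fallback presumably makes the hub client use the locally cached HF login
# when the TOKEN env var is not set.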
auth_token = os.environ.get("TOKEN") or True
total_runs = 0

tokenizer = AutoTokenizer.from_pretrained("Iker/ClickbaitFighter-10B-pro")
model = AutoModelForCausalLM.from_pretrained(
    "Iker/ClickbaitFighter-10B-pro",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # quantization_config=BitsAndBytesConfig(
    #     load_in_4bit=True,
    #     bnb_4bit_compute_dtype=torch.bfloat16,
    #     bnb_4bit_use_double_quant=True,
    # ),
    # attn_implementation="flash_attention_2",
)
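# Conservative decoding setup: low temperature plus repetition penalties keep
# the one-sentence summaries short and close to the source text.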
generation_config = GenerationConfig(
    max_new_tokens=256,  # Summaries are short; we don't need more tokens
    min_new_tokens=1,  # We don't want empty summaries
    do_sample=True,  # Slightly better than greedy sampling
    num_beams=1,
    use_cache=True,  # Efficiency
    top_k=40,
    top_p=0.1,
    repetition_penalty=1.1,  # Helps keep the model from looping
    encoder_repetition_penalty=1.1,  # Encourages the model to quote the original text
    temperature=0.15,  # Low temperature to keep the model from generating overly creative text
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)
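# Strings that should terminate generation, covering the end-of-turn markers
# of the various chat templates the model may emit.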
stop_words = [
    "<s>",
    "</s>",
    "\\n",
    "[/INST]",
    "[INST]",
    "### User:",
    "### Assistant:",
    "###",
    "<start_of_turn>",
    "<end_of_turn>",
    "<end_of_turn>\\n",
    "<eos>",
    "<|im_end|>",
]
# Custom logits processor that forces EOS once any stop word has been generated
stop_criteria = LogitsProcessorList(
    [
        StopAfterTokenIsGenerated(
            stops=[
                torch.tensor(tokenizer.encode(stop_word, add_special_tokens=False))
                for stop_word in stop_words
            ],
            eos_token_id=tokenizer.eos_token_id,
        )
    ]
)
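# Override the Gradio saver so every flagged component is stored as a plain
# string column; the stock saver writes images/audio out as separate files,
# which this app does not need.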
class HuggingFaceDatasetSaver_custom(gr.HuggingFaceDatasetSaver):
    def _deserialize_components(
        self,
        data_dir,
        flag_data: list[Any],
        flag_option: str = "",
        username: str = "",
    ) -> tuple[dict[Any, Any], list[Any]]:
        """Deserialize components and return the corresponding row for the flagged sample.

        Every component is stored as a string value.
        """
        # Generate the row corresponding to the flagged sample
        features = OrderedDict()
        row = []
        for component, sample in zip(self.components, flag_data):
            label = component.label or ""
            features[label] = {"dtype": "string", "_type": "Value"}
            row.append(sample)
        features["flag"] = {"dtype": "string", "_type": "Value"}
        features["username"] = {"dtype": "string", "_type": "Value"}
        row.append(flag_option)
        row.append(username)
        return features, row
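# Footer appended to every summary, asking users to rate it with the 👍/👎 buttons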
def finish_generation(text: str) -> str:
    return f"{text}\n\n⬇️ Ayuda a mejorar la herramienta marcando si el resumen es correcto o no. ⬇️"
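# Pick the prompt matching the slider value (0 = plain summary, 50 = summary +
# clickbait debunk, 100 = short answer to the clickbait) and run generation.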
@spaces.GPU  # The Space runs on ZeroGPU: request a GPU slot for generation
def run_model(mode, title, text):
    if mode == 0:
        prompt = summarize_prompt(title, text)
    elif mode == 50:
        prompt = summarize_clickbait_large_prompt(title, text)
    elif mode == 100:
        prompt = summarize_clickbait_short_prompt(title, text)
    else:
        raise ValueError("Mode not supported")

    formatted_prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )

    model_inputs = tokenizer(
        [formatted_prompt], return_tensors="pt", add_special_tokens=False
    )

    streamer = TextStreamer(
        tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    model_output = model.generate(
        **model_inputs.to(model.device),
        streamer=streamer,
        generation_config=generation_config,
        logits_processor=stop_criteria,
    )
    # yield streamer  # Does not work properly on Zero environment

    # Decode only the newly generated tokens, skipping the prompt
    temp = tokenizer.batch_decode(
        model_output[:, model_inputs["input_ids"].shape[-1] :],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )[0]

    return temp
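# Gradio treats generator functions as streaming endpoints: each yield updates
# the three output textboxes (title, summary, full article text).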
def generate_text(
    url: str, mode: int, progress=gr.Progress(track_tqdm=False)
) -> tuple[str, str, str]:
    global cache_handler
    global total_runs

    total_runs += 1
    print(f"Total runs: {total_runs}. Last run: {datetime.datetime.now()}")

    url = url.strip()

    if url.startswith("https://twitter.com/") or url.startswith("https://x.com/"):
        yield (
            "🤖 Vaya, parece que has introducido la URL de un tweet. No puedo acceder a tweets, tienes que introducir la URL de una noticia.",
            "❌❌❌ Si el tweet contiene una noticia, dame la URL de la noticia ❌❌❌",
            "Error",
        )
        return
    # 1) Download the article
    progress(0, desc="🤖 Accediendo a la noticia")

    # First, check if the URL is in the cache
    title, text, temp = cache_handler.get_from_cache(url, mode)
    if title is not None and text is not None and temp is not None:
        temp = finish_generation(temp)
        yield title, temp, text
        return

    try:
        title, text, url = download_text_and_title(url)
    except Exception as e:
        print(e)
        title = None
        text = None

    if title is None or text is None:
        yield (
            "🤖 No he podido acceder a la noticia, asegúrate de que la URL es correcta y de que es posible acceder a la noticia desde un navegador.",
            "❌❌❌ Inténtalo de nuevo ❌❌❌",
            "Error",
        )
        return

    # Test if the redirected and clean url is in the cache
    _, _, temp = cache_handler.get_from_cache(url, mode, second_try=True)
    if temp is not None:
        temp = finish_generation(temp)
        yield title, temp, text
        return
    progress(0.5, desc="🤖 Leyendo noticia")

    try:
        temp = run_model(mode, title, text)
    except Exception as e:
        print(e)
        yield (
            "🤖 El servidor no se encuentra disponible.",
            "❌❌❌ Inténtalo de nuevo más tarde ❌❌❌",
            "Error",
        )
        return

    cache_handler.add_to_cache(
        url=url, title=title, text=text, summary_type=mode, summary=temp
    )

    temp = finish_generation(temp)
    yield title, temp, text

    hits, misses, cache_len = cache_handler.get_cache_stats()
    print(
        f"Hits: {hits}, misses: {misses}, cache length: {cache_len}. "
        f"Percent hits: {round(hits / (hits + misses) * 100, 2)}%."
    )
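# Shared state: the article/summary cache (up to 1000 entries) and the writer
# that appends flagged examples to the Iker/Clickbait-News dataset.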
cache_handler = CacheHandler(max_cache_size=1000)
hf_writer = HuggingFaceDatasetSaver_custom(
    auth_token, "Iker/Clickbait-News", private=True, separate_dirs=False
)
demo = gr.Interface(
    generate_text,
    inputs=[
        gr.Textbox(
            label="🌐 URL de la noticia",
            info="Introduce la URL de la noticia que deseas resumir.",
            value="https://ikergarcia1996.github.io/Iker-Garcia-Ferrero/",
            interactive=True,
        ),
        gr.Slider(
            minimum=0,
            maximum=100,
            step=50,
            value=50,
            label="🎚️ Nivel de resumen",
            info="""¿Hasta qué punto quieres resumir la noticia?

Si solo deseas un resumen, selecciona 0.
Si buscas un resumen y desmontar el clickbait, elige 50.
Para obtener solo la respuesta al clickbait, selecciona 100""",
            interactive=True,
        ),
    ],
    outputs=[
        gr.Textbox(
            label="📰 Titular de la noticia",
            interactive=False,
            placeholder="Aquí aparecerá el título de la noticia.",
        ),
        gr.Textbox(
            label="🗒️ Resumen",
            interactive=False,
            placeholder="Aquí aparecerá el resumen de la noticia.",
        ),
        gr.Textbox(
            label="Noticia completa",
            visible=False,
            render=False,
            interactive=False,
            placeholder="Aquí aparecerá la noticia completa.",
        ),
    ],
    # title="⚔️ Clickbait Fighter! ⚔️",
    thumbnail="https://huggingface.co/spaces/Iker/ClickbaitFighter/resolve/main/logo2.png",
    theme="JohnSmith9982/small_and_pretty",
    description="""
<table>
<tr>
<td style="width:100%"><img src="https://huggingface.co/spaces/Iker/ClickbaitFighter/resolve/main/head.png" align="right" width="100%"></td>
</tr>
</table>

<p align="justify">Esta Inteligencia Artificial es capaz de generar un resumen de una sola frase que revela la verdad detrás de un titular sensacionalista o clickbait. Solo tienes que introducir la URL de la noticia. La IA accederá a la noticia, la leerá y en cuestión de segundos generará un resumen de una sola frase que revele la verdad detrás del titular.</p>

🎚 Ajusta el nivel de resumen con el control deslizante. Cuanto más alto, más corto será el resumen.

⌚ La IA se encuentra corriendo en un hardware bastante modesto, debería tardar menos de 30 segundos en generar el resumen, pero si muchos usuarios usan la app a la vez, tendrás que esperar tu turno.

💸 Este es un proyecto sin ánimo de lucro, no se genera ningún tipo de ingreso con esta app. Los datos, la IA y el código se publicarán para su uso en la investigación académica. No puedes usar esta app para ningún uso comercial.

🧪 El modelo se encuentra en fase de desarrollo, si quieres ayudar a mejorarlo puedes usar los botones 👍 y 👎 para valorar el resumen. ¡Gracias por tu ayuda!""",
    article="Esta Inteligencia Artificial ha sido generada por Iker García-Ferrero. Puedes saber más sobre mi trabajo en mi [página web](https://ikergarcia1996.github.io/Iker-Garcia-Ferrero/) o mi perfil de [X](https://twitter.com/iker_garciaf). Puedes ponerte en contacto conmigo a través de correo electrónico (ver web) y X.",
    cache_examples=False,
    allow_flagging="manual",
    flagging_options=[("👍", "correct"), ("👎", "incorrect")],
    flagging_callback=hf_writer,
    concurrency_limit=20,
)
demo.queue(max_size=None)
demo.launch(share=False)