# NOTE: this file was exported from a Hugging Face Space whose status page
# reported "Runtime error" at the time of capture.
import warnings
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from happytransformer import HappyTextToText, TTSettings
from styleformer import Styleformer
from sentence_transformers import SentenceTransformer
import chromadb
import pandas as pd
import logging
import re
from threading import Thread
import hashlib
import diskcache as dc
import nltk

# Sentence-boundary models required by nltk.sent_tokenize (used in split_into_chunks).
nltk.download('punkt_tab')
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.INFO,  # filename="py_log.log",filemode="w",
                    format="%(asctime)s %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

# For chromadb collection
MAX_TOKENS = 512  # per-chunk budget, counted as whitespace-separated words
client = chromadb.Client()  # in-process chroma client
embedder = SentenceTransformer('all-MiniLM-L6-v2')  # embedding model for chunks and queries
collection_name = 'papers'

# For grammar checker
happy_tt = HappyTextToText("T5", "vennify/t5-base-grammar-correction")
grammar_cache = dc.Cache('grammar_cache')  # on-disk cache of corrected texts

# For academic style checks
sf = Styleformer(style=0)  # style=0 — presumably casual-to-formal transfer; TODO confirm
style_cache = dc.Cache('style_cache')

# For text generation
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
model.generation_config.max_new_tokens = 2048
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_cache = dc.Cache('model_cache')  # on-disk cache of generated articles
def generate_key(text):
    """Return a stable cache key for *text*: the hex MD5 digest of its UTF-8 bytes."""
    digest = hashlib.md5(text.encode())
    return digest.hexdigest()
def split_into_chunks(text, max_tokens=MAX_TOKENS):
    """Split *text* into sentence-aligned chunks of at most *max_tokens* words.

    Sentences come from nltk.sent_tokenize; "tokens" are approximated as
    whitespace-separated words. A sentence longer than *max_tokens* becomes
    a chunk on its own (it is never split mid-sentence).

    Args:
        text: the text to split.
        max_tokens: word budget per chunk (default: module-level MAX_TOKENS).

    Returns:
        list[str]: non-empty, stripped chunks in original order.
    """
    sentences = nltk.sent_tokenize(text)
    chunks, current = [], ""
    current_tokens = 0
    for sentence in sentences:
        sentence_tokens = len(sentence.split())
        if current_tokens + sentence_tokens <= max_tokens:
            current += sentence + ' '
            current_tokens += sentence_tokens
        else:
            # Bug fix: when the very first sentence already exceeds
            # max_tokens, `current` is "" and the original appended an
            # empty chunk here.  Only flush non-empty accumulations.
            if current:
                chunks.append(current.strip())
            current, current_tokens = sentence + ' ', sentence_tokens
    if current:
        chunks.append(current.strip())
    return chunks
def clean_text(text):
    """Normalise raw paper text extracted from PDFs.

    Single newlines inside sentences become spaces while blank-line
    paragraph breaks are kept, words hyphenated across line breaks are
    re-joined, numeric citation markers ([7]) and "Fig."/"Figure" tokens
    are removed.  Returns the cleaned paragraphs joined by double newlines.
    """
    substitutions = (
        (r'\n(?!\n)', ' '),          # newline inside a sentence -> space
        (r'\n{2,}', '\n\n'),         # squeeze blank-line runs to one break
        (r'(\w)-\s+(\w)', r'\1\2'),  # rejoin hyphen-split words
        (r'\[\d+\]', ''),            # drop citation markers like [7]
        (r'Fig\.|Figure', ''),       # drop figure references
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    stripped = (part.strip() for part in text.split('\n'))
    return '\n\n'.join(part for part in stripped if part)
def get_collection() -> chromadb.Collection:
    """Return the 'papers' chroma collection, building it on first call.

    On a miss, downloads the somosnlp scientific-papers CSV, takes the
    first 200 papers, cleans and chunks each one, embeds the chunks and
    upserts them into a freshly created collection.  On later calls the
    existing collection is fetched as-is.

    Returns:
        chromadb.Collection: the populated (or pre-existing) collection.
    """
    collection_names = [collection.name for collection in client.list_collections()]
    logging.info(f"Client collection names: {collection_names}")
    if collection_name not in collection_names:
        logging.info("Creation of a collection...")
        collection = client.create_collection(name=collection_name)
        papers = pd.read_csv("hf://datasets/somosnlp-hackathon-2022/scientific_papers_en/scientific_paper_en.csv")
        logging.info("The data downloaded from url.")
        papers = papers.drop(['id'], axis=1)
        papers = papers.iloc[:200]
        for i in range(200):
            paper = papers.iloc[i]
            idx = paper.name  # original dataframe index, used in the chunk ids
            full_text = clean_text('Abstract ' + paper['abstract'] + ' ' + paper['text_no_abstract'])
            chunks = split_into_chunks(full_text)
            # `chunk_id` (was `id`) avoids shadowing the builtin id().
            for chunk_id, chunk in enumerate(chunks):
                embeddings = embedder.encode([chunk])
                # ids wrapped in a list to match the list-valued `documents`
                # and `embeddings` arguments.
                collection.upsert(ids=[f"paper{idx}_chunk_{chunk_id}"],
                                  documents=[chunk],
                                  embeddings=embeddings,)
            logging.info(f"Collection upsert: The content of paper_{idx} was chunked and collected in vector db!")
        logging.info("Collection is filled!\n")
    else:
        collection = client.get_collection(name=collection_name)
        logging.info(f"Collection '{collection_name}' already exists!")
    return collection
def fix_grammar(text: str):
    """Yield progressively grammar-corrected versions of *text*.

    A previously corrected text is served straight from 'grammar_cache'.
    Otherwise *text* is split into ~40-word chunks, each is corrected by
    the T5 grammar model, and the growing concatenation is yielded after
    every chunk.  The final result is cached for a day unless any chunk
    failed (failed chunks fall back to their original text).
    """
    logging.info(f"\n---Fix Grammar input:---\n{text}")
    key = generate_key(text)
    if key in grammar_cache:
        logging.info(f"Similar request was found in 'grammar_cache' and retrieved from it!")
        yield grammar_cache[key]
        return
    settings = TTSettings(num_beams=5, min_length=1)
    corrected = ""
    had_error = False
    for piece in split_into_chunks(text=text, max_tokens=40):
        try:
            outcome = happy_tt.generate_text(f"grammar: {piece}", args=settings)
            fixed_part = f"{outcome.text} "
        except Exception as exc:
            had_error = True
            logging.error(f"Error correcting grammar: {exc}")
            fixed_part = f"{piece} "
        corrected += fixed_part
        yield corrected
    if not had_error:
        grammar_cache.set(key, corrected, expire=86400)
        logging.info(f"The result was cached in 'grammar_cache'!")
def fix_academic_style(informal_text: str):
    """Yield progressively formalised versions of *informal_text*.

    A previously transformed text is served straight from 'style_cache'.
    Otherwise the text is split into ~25-word chunks, each is run through
    the Styleformer casual-to-formal model, and the growing concatenation
    is yielded after every chunk.  The final result is cached for a day
    unless any chunk failed or returned None (those chunks fall back to
    their original text).
    """
    logging.info(f"\n---Fix Academic Style input:---\n{informal_text}")
    key = generate_key(informal_text)
    if key in style_cache:
        logging.info(f"Similar request was found in 'style_cache' and retrieved from it!")
        yield style_cache[key]
        return
    formal = ""
    failed = False
    for piece in split_into_chunks(text=informal_text, max_tokens=25):
        try:
            transferred = sf.transfer(piece)
            if transferred is None:
                # The model signals failure by returning None.
                failed = True
                logging.warning("---COULD NOT FIX ACADEMIC STYLE!\n")
                addition = f"{piece} "
            else:
                addition = f"{transferred} "
        except Exception as exc:
            failed = True
            logging.error(f"Error in academic style transformation: {exc}")
            addition = f"{piece} "
        formal += addition
        yield formal
    if not failed:
        style_cache.set(key, formal, expire=86400)
        logging.info(f"The result was cached in 'style_cache'!")
def _chat_stream(initial_text: str, parts: list):
    """Yield a progressively generated academic text for *initial_text*.

    An identical (text, parts) request is served straight from
    'model_cache'.  Otherwise the closest paper chunk is retrieved from
    the chroma collection as context, a chat prompt is built for the Qwen
    model, and generated tokens are streamed back — the accumulated
    response is yielded after every new fragment, then cached for a day.

    Args:
        initial_text: the user's current draft text.
        parts: the paper sections the user wants generated.
    """
    logging.info(f"\n---Generate Article input:---\n{initial_text}")
    parts = ", ".join(parts).lower()
    for_cache = initial_text + ' ' + parts
    key = generate_key(for_cache)
    if key in model_cache:
        logging.info(f"Similar request was found in 'model_cache' and retrieved from it!")
        yield model_cache[key]
    else:
        text_embedding = embedder.encode([initial_text])
        chroma_collection = get_collection()
        results = chroma_collection.query(
            query_embeddings=text_embedding,
            n_results=1
        )
        # Bug fix: query() returns documents as a list of per-query lists,
        # so the original `results['documents'][0]` was a *list* — the
        # `== ""` check could never fire and the list's repr was embedded
        # in the prompt.  Extract the single retrieved document instead.
        documents = results.get('documents') or [[]]
        context = documents[0][0] if documents[0] else ""
        if context == "":
            logging.warning(f"COLLECTION QUERY: No context was found in the database!")
        messages = [
            {"role": "system", "content": """You are helpful Academic Research Assistant which helps to generate
                                             necessary parts of the reserch based on the provided context.
                                             The context is the following: 'written text' - this is the text that user
                                             has for now and want to complete, 'parts' - those are the parts of paper
                                             user needs to complete (it could be the abstract, introduction, methodology,
                                             discussion, conclusion, or full text), 'context' - the similar article
                                             the structure of which can be used as a base for the text (it can be empty
                                             in case of absence of similar papers in the database.). The output should be
                                             only generated article (or parts of it). The responce must be provided as a
                                             raw text. Be precise and follow the structure of academic papers parts."""},
            {"role": "user", "content": f"'written text': {initial_text}\n 'parts': {parts}\n 'context': {context}"},
        ]
        input_text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )
        inputs = tokenizer([input_text], return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(
            tokenizer=tokenizer, skip_prompt=True, timeout=160.0, skip_special_tokens=True
        )
        generation_kwargs = {
            **inputs,
            "streamer": streamer,
        }
        # generate() runs in a worker thread; the streamer feeds this one.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        response = ""
        for new_text in streamer:
            response += new_text
            yield response
        # Bug fix: reap the generation thread before caching the result.
        thread.join()
        model_cache.set(key, response, expire=86400)
        logging.info(f"The result was cached in 'model_cache'!")
def _stream_guarded(stream, goal_label: str, done_template: str):
    """Drive *stream* to completion, yielding each partial result.

    Shared guard logic for all predict() goals: yields a fallback message
    when the stream produced nothing, logs the final result with
    *done_template* ('{}' is replaced by the result), and on any exception
    logs it under *goal_label* and yields a retry hint.
    """
    final = ""
    try:
        for partial in stream:
            final = partial
            yield final
        if not final:
            yield "Generation failed or timed out. Please try again!"
        logging.info(done_template.format(final))
    except Exception as e:
        logging.error(f"Error in '{goal_label}' occured: {e}")
        yield "Try to wait a little bit and resend your request!"


def predict(goal: str, parts: list, context: str):
    """Dispatch the user's request and stream back partial results.

    Args:
        goal: 'Fix Academic Style', 'Fix Grammar', or anything else for
            full text generation.
        parts: paper sections to generate (used by text generation only).
        context: the user's current text; an empty string short-circuits
            with a prompt to write something first.

    Yields:
        str: progressively updated output (or a user-facing error hint).
    """
    if context == "":
        yield "Write your text first!"
        logging.info("No context was provided!")
    elif goal == 'Fix Academic Style':
        # The three goal branches previously triplicated the same
        # try/stream/fallback pattern; it now lives in _stream_guarded.
        yield from _stream_guarded(fix_academic_style(context),
                                   'Fix Academic Style',
                                   "\n---Academic style corrected:---\n {}\n")
    elif goal == 'Fix Grammar':
        yield from _stream_guarded(fix_grammar(context),
                                   'Fix Grammar',
                                   "\n---Grammar corrected:---\n{}\n")
    else:
        yield from _stream_guarded(_chat_stream(context, parts),
                                   'Write Text',
                                   "\nThe text was generated!\n{}")