import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import requests
import tqdm as t
import re
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import pytesseract
from PIL import Image
from collections import deque
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER") | |
| model = AutoModelForTokenClassification.from_pretrained( | |
| "dslim/bert-base-NER") | |
| summarizer = pipeline( | |
| "summarization", model="facebook/bart-large-cnn", device=device) | |
| qa = pipeline("question-answering", | |
| model="deepset/roberta-base-squad2", device=device) | |

def extract_text(image):
    """
    Extracts text from an image using OCR.

    Args:
        image (PIL.Image.Image): Input image.

    Returns:
        dict: Detected words keyed by index, each with text, confidence,
        and bounding-box coordinates.
    """
    result = pytesseract.image_to_data(image, output_type='dict')
    n_boxes = len(result['level'])
    data = {}
    k = 0
    for i in range(n_boxes):
        # Tesseract confidences run 0-100, with -1 marking non-text blocks.
        if result['conf'][i] >= 30 and result['text'][i].strip():
            x, y, w, h = (result['left'][i], result['top'][i],
                          result['width'][i], result['height'][i])
            data[k] = {
                "coordinates": (x, y, w, h),
                "text": result['text'][i],
                "conf": result['conf'][i],
            }
            k += 1
    return data
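
# For example (illustrative output only; the file name is hypothetical and the
# exact boxes depend on the image and the installed Tesseract version):
#   extract_text(Image.open("receipt.png"))
#   -> {0: {"coordinates": (12, 8, 140, 22), "text": "Total", "conf": 91}, ...}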

def strong_entities(question):
    """Collect high-confidence named entities from the question as search terms."""
    ner_results = ner(question)
    search_terms = []
    current_term = ""
    for token in ner_results:
        if token["score"] >= 0.99:
            word = token["word"]
            # Re-attach WordPiece continuations ("##stein" -> "stein") to the span.
            if word.startswith("##"):
                current_term += word[2:]
            else:
                current_term += " " + word
        elif current_term:
            search_terms.append(current_term.strip())
            current_term = ""
    if current_term:
        search_terms.append(current_term.strip())
    # Fall back to the raw question words if no confident entity was found.
    if not search_terms:
        return question.split()
    return search_terms[0].split()
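
# For example (illustrative; token scores depend on the NER model version):
#   strong_entities("Where was Albert Einstein born?")
#   -> ["Albert", "Einstein"]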

def wiki_search(question):
    """Fetch Wikipedia plain-text extracts for the question's strongest entities."""
    search_terms = strong_entities(question)
    URL = "https://en.wikipedia.org/w/api.php"
    corpus = []
    for term in set(search_terms):  # removing duplicates
        params = {
            "action": "query",
            "format": "json",
            "titles": term,
            "prop": "extracts",
            "explaintext": True,
        }
        try:
            response = requests.get(URL, params=params)
            if response.status_code == 200:
                data = response.json()
                for page_id, page_data in t.tqdm(data["query"]["pages"].items()):
                    if "extract" in page_data:  # skip pages with no extract
                        corpus.append(page_data["extract"])
            else:
                print("Failed to retrieve data:", response.status_code)
        except Exception as e:
            print("Failed to retrieve data:", e)
    # Split each article on section headings and keep sections of five or more words.
    final_corpus = []
    for text in corpus:
        sections = re.split("\n\n\n== |==\n\n", text)
        for section in sections:
            if len(section.split()) >= 5:
                final_corpus.append(section)
    # Return a list so callers can combine sections with other text.
    return final_corpus[:1]

def semantic_search(corpus, question):
    """Return the document in `corpus` most similar to `question`, with its score."""
    question_embedding = embedder.encode(question)
    max_similarity = -1
    most_similar_doc = None
    for doc in t.tqdm(corpus):
        # Summarize long documents before embedding so they stay comparable in length.
        if len(doc.split()) >= 130:
            doc_summary = summarizer(
                doc, max_length=130, min_length=30, do_sample=False)
            if len(doc_summary) > 0 and "summary_text" in doc_summary[0]:
                summarized_doc = doc_summary[0]["summary_text"]
            else:
                summarized_doc = doc
        else:
            summarized_doc = doc
        doc_embedding = embedder.encode(summarized_doc)
        similarity = cosine_similarity(
            [question_embedding], [doc_embedding])[0][0]
        if similarity > max_similarity:
            max_similarity = similarity
            most_similar_doc = summarized_doc
    return most_similar_doc, max_similarity
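
# For example (hypothetical corpus; exact scores vary by model version):
#   best, score = semantic_search(
#       ["Paris is the capital of France.", "GPUs accelerate matrix math."],
#       "What is the capital of France?")
#   # best -> "Paris is the capital of France."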

def dm(q, a, corpus, new_q, max_history_size=5):
    """Dialogue memory: pick the stored corpus most relevant to the new question."""
    if corpus is None:
        # No OCR text was available, so fall back to a web-sourced corpus.
        corpus = " ".join(wiki_search(new_q))
    history = deque(maxlen=max_history_size)
    history.append({"question": q, "answer": a, "corpus": corpus})
    best_corpus_index = None
    max_similarity = -1
    for i in range(len(history)):
        _, q_similarity = semantic_search([history[i]["corpus"]], new_q)
        _, a_similarity = semantic_search(
            [history[i]["corpus"]], history[i]["answer"])
        similarity = max(q_similarity, a_similarity)
        if similarity > max_similarity:
            max_similarity = similarity
            best_corpus_index = i
    if best_corpus_index is not None:
        return history[best_corpus_index]["corpus"]
    return corpus

def first_corp(data, question, use_web=False):
    """Build the initial corpus from OCR text, optionally augmented with Wikipedia."""
    text = " ".join(entry["text"] for entry in data.values())
    if use_web:
        corpus = [cp + " " + text for cp in wiki_search(question)]
        if not corpus:  # the web search came back empty
            corpus = [text]
    else:
        corpus = [text]
    return " ".join(corpus)

def Qa(image, new_q, internet_access=False):
    """Answer `new_q` about `image`, optionally augmenting OCR text with web results."""
    # Seed conversation history; note it is rebuilt on every call, so earlier
    # questions do not persist between invocations.
    old_q = ["how are you?"]
    old_a = ["I am fine, thank you."]
    im_text = extract_text(image)
    if im_text:  # text was extracted from the image
        old_corpus = [first_corp(im_text, old_q[-1], use_web=internet_access)]
    else:
        old_corpus = []
    if internet_access:
        if not old_corpus:
            # Pass None as corpus so dm falls back to a web search.
            corpus = dm(old_q[-1], old_a[-1], None, new_q)
        else:
            corpus = dm(old_q[-1], old_a[-1], old_corpus[0], new_q)
    else:
        corpus = old_corpus[0] if old_corpus else None
    if not corpus:
        return "No text could be extracted from the image.", ""
    a = qa(question=new_q, context=corpus)
    old_q.append(new_q)
    old_a.append(a["answer"])
    old_corpus.append(corpus)
    old_conversations = "\n".join(
        f"Q: {q}\nA: {a}" for q, a in zip(old_q, old_a))
    return a["answer"], old_conversations