VDNT11 committed · verified
Commit 1316bcd · 1 Parent(s): 758690d

Update app.py

Files changed (1)
  1. app.py +17 -44
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
 import torch
 from PIL import Image
 import os
-from transformers import BlipProcessor, BlipForConditionalGeneration, AutoModelForSeq2SeqLM, AutoTokenizer, VitsTokenizer, VitsModel, AutoModelForCausalLM, set_seed
+from transformers import BlipProcessor, BlipForConditionalGeneration, AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
 from IndicTransToolkit import IndicProcessor
 from gtts import gTTS
 import soundfile as sf
@@ -12,27 +12,15 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.docstore.document import Document
 import PyPDF2
 import tempfile
-from huggingface_hub import login
-
-# Authenticate with Hugging Face token
-if os.getenv("HF_TOKEN"):
-    login(token=os.getenv("HF_TOKEN"))
-else:
-    raise ValueError("HF_TOKEN environment variable not set. Please set it in Hugging Face Spaces settings.")
 
 # Initialize BLIP for image captioning
-blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
-blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to("cuda" if torch.cuda.is_available() else "cpu")
+blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
 
-# Initialize Mixtral-8x7B-Instruct for conversational tasks
-mixtral_model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
-mixtral_tokenizer = AutoTokenizer.from_pretrained(mixtral_model_name)
-mixtral_model = AutoModelForCausalLM.from_pretrained(
-    mixtral_model_name,
-    load_in_4bit=True,
-    device_map="auto",
-    torch_dtype=torch.bfloat16
-)
+# Initialize Gemma-2B-Instruct for conversational tasks
+gemma_model_name = "google/gemma-2b-it"
+gemma_tokenizer = AutoTokenizer.from_pretrained(gemma_model_name)
+gemma_model = AutoModelForCausalLM.from_pretrained(gemma_model_name)
 
 # Initialize vector store and embeddings for RAG
 embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
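Note: google/gemma-2b-it is a gated checkpoint, and this commit also removes the HF_TOKEN login block. If the Space still needs authenticated downloads, one option (a sketch, assuming the HF_TOKEN secret from the removed block is still configured in the Space settings) is to pass the token per call:

    import os
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Sketch: authenticate per call instead of via a global login().
    # Assumes the HF_TOKEN secret is still set in the Space settings.
    hf_token = os.getenv("HF_TOKEN")
    gemma_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=hf_token)
    gemma_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", token=hf_token)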
@@ -41,7 +29,7 @@ temp_dir = tempfile.mkdtemp()
 
 def generate_caption(image):
     image = image.convert("RGB")
-    inputs = blip_processor(image, "image of", return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
+    inputs = blip_processor(image, "image of", return_tensors="pt")
     with torch.no_grad():
         generated_ids = blip_model.generate(**inputs)
     caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)
@@ -54,13 +42,11 @@ def translate_caption(caption, target_languages):
     model_IT2 = torch.quantization.quantize_dynamic(model_IT2, {torch.nn.Linear}, dtype=torch.qint8)
     ip = IndicProcessor(inference=True)
     src_lang = "eng_Latn"
-    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-    model_IT2.to(DEVICE)
     input_sentences = [caption]
     translations = {}
     for tgt_lang in target_languages:
         batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
-        inputs = tokenizer_IT2(batch, truncation=True, padding="longest", return_tensors="pt").to(DEVICE)
+        inputs = tokenizer_IT2(batch, truncation=True, padding="longest", return_tensors="pt")
         with torch.no_grad():
             generated_tokens = model_IT2.generate(**inputs, use_cache=True, min_length=0, max_length=256, num_beams=5, num_return_sequences=1)
         with tokenizer_IT2.as_target_tokenizer():
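Note: dropping the .to(DEVICE) calls is consistent with the quantize_dynamic line above, since PyTorch dynamic quantization runs on CPU only. A minimal self-contained sketch of the same pattern:

    import torch
    import torch.nn as nn

    # Dynamic quantization stores nn.Linear weights as int8 and quantizes
    # activations on the fly at inference time; it is CPU-only.
    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
    qmodel = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
    with torch.no_grad():
        out = qmodel(torch.randn(1, 16))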
@@ -75,18 +61,6 @@ def generate_audio_gtts(text, lang_code):
     tts.save(output_file)
     return output_file
 
-def generate_audio_fbmms(text, model_name):
-    output_file = os.path.join(temp_dir, f"{model_name.split('/')[-1]}.wav")
-    tokenizer = VitsTokenizer.from_pretrained(model_name)
-    model = VitsModel.from_pretrained(model_name)
-    inputs = tokenizer(text=text, return_tensors="pt")
-    set_seed(555)
-    with torch.no_grad():
-        outputs = model(**inputs)
-    waveform = outputs.waveform[0].cpu().numpy()
-    sf.write(output_file, waveform, samplerate=model.config.sampling_rate)
-    return output_file
-
 def process_document(file):
     global vector_store
     if file.name.endswith(".pdf"):
@@ -96,7 +70,7 @@ def process_document(file):
             text += page.extract_text() or ""
     else:
         text = file.read().decode("utf-8")
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
     chunks = text_splitter.split_text(text)
     documents = [Document(page_content=chunk) for chunk in chunks]
     vector_store = FAISS.from_documents(documents, embeddings)
@@ -107,13 +81,13 @@ def chat_with_llm(message, history):
     global vector_store
     context = ""
     if vector_store:
-        docs = vector_store.similarity_search(message, k=3)
+        docs = vector_store.similarity_search(message, k=2)
         context = "\n".join([doc.page_content for doc in docs])
-    prompt = f"[INST] You are a helpful assistant. Use the following context to answer the question accurately:\n\n{context}\n\nQuestion: {message} [/INST]"
-    inputs = mixtral_tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
+    prompt = f"<start_of_turn>user\nYou are a helpful assistant. Use the following context to answer the question accurately:\n\n{context}\n\nQuestion: {message}\n<end_of_turn>\n<start_of_turn>assistant"
+    inputs = gemma_tokenizer(prompt, return_tensors="pt")
     with torch.no_grad():
-        outputs = mixtral_model.generate(**inputs, max_length=1000, num_return_sequences=1, temperature=0.7)
-    response = mixtral_tokenizer.decode(outputs[0], skip_special_tokens=True)
+        outputs = gemma_model.generate(**inputs, max_length=500, num_return_sequences=1, temperature=0.7)
+    response = gemma_tokenizer.decode(outputs[0], skip_special_tokens=True)
     return response.replace(prompt, "").strip()
 
 def image_tab(image, target_languages):
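Note: Gemma's chat template labels the reply turn "model", not "assistant", so the hand-built prompt above may not match what the model saw in training. A sketch that lets the tokenizer render the template instead (reusing the message and context variables from the hunk above):

    # Sketch: apply_chat_template builds the <start_of_turn> markup for us.
    chat = [{"role": "user", "content": f"Use the following context to answer:\n\n{context}\n\nQuestion: {message}"}]
    input_ids = gemma_tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt")
    with torch.no_grad():
        outputs = gemma_model.generate(input_ids, max_new_tokens=300)
    # Decode only the newly generated tokens; max_new_tokens also avoids the
    # max_length=500 cap being eaten by the prompt plus retrieved context.
    response = gemma_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)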
@@ -133,7 +107,7 @@ with gr.Blocks() as demo:
     with gr.Tabs():
         with gr.TabItem("Image Processing"):
             image_input = gr.Image(type="pil", label="Upload Image")
-            lang_select = gr.CheckboxGroup(["hin_Deva", "mar_Deva", "guj_Gujr", "urd_Arab"], label="Select Target Languages", value=["hin_Deva", "mar_Deva"])
+            lang_select = gr.CheckboxGroup(["hin_Deva", "guj_Gujr", "urd_Arab"], label="Select Target Languages", value=["hin_Deva"])
             process_btn = gr.Button("Process Image")
             caption_output = gr.Textbox(label="Generated Caption")
             translation_output = gr.JSON(label="Translations")
@@ -151,5 +125,4 @@ with gr.Blocks() as demo:
     msg.submit(chat_with_llm, inputs=[msg, chatbot], outputs=chatbot)
     clear.click(lambda: None, None, chatbot, queue=False)
 
-demo.launch()
-
+demo.launch()
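Note (pre-existing, not changed by this commit): chat_with_llm returns a plain string, but msg.submit wires its output to a gr.Chatbot, which expects a list of (user, bot) pairs. A sketch of the usual pattern:

    # Sketch: append the new exchange to the running history instead of
    # returning a bare string, so gr.Chatbot can render it.
    def chat_with_llm(message, history):
        ...  # build `response` as before
        return (history or []) + [(message, response)]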
 