chatbox / app.py
anaspro
updatE
a4a88c0
raw
history blame
6.11 kB
import gradio as gr
import cv2
import torch
from PIL import Image
from pathlib import Path
from threading import Thread
from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
import spaces
import time
import os
# Model config - single model: Shako v4.
model_id = "anaspro/Shako-4B-it"
# Alias kept so any code that refers to `model_name` keeps working.
model_name = model_id

# Optional Hugging Face access token, read from the environment. When the
# variable is unset this is None and is passed through unchanged
# (NOTE(review): confirm the Space defines HF_TOKEN if the repo is gated).
hf_token = os.environ.get("HF_TOKEN")

# Load the processor and the model exactly once each, with the same token.
processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    token=hf_token,
).eval()
# TODO: add timestamps to the extracted frames (planned, not implemented yet).
def extract_video_frames(video_path, num_frames=8):
    """Sample up to ``num_frames`` evenly spaced frames from a video.

    Args:
        video_path: Path of the video, handed to ``cv2.VideoCapture``.
        num_frames: Maximum number of frames to sample. Values <= 0 yield
            an empty list.

    Returns:
        A list of PIL images converted from BGR to RGB, in increasing frame
        position. Frames that cannot be read are skipped, so the list may be
        shorter than ``num_frames`` (or empty).
    """
    if num_frames <= 0:
        # Nothing was requested; this also avoids dividing by zero below.
        return []

    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Distance between sampled frames; at least 1 so short videos (or an
        # unknown frame count of 0) still advance.
        step = max(total_frames // num_frames, 1)
        for i in range(num_frames):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
            ret, frame = cap.read()
            if not ret:
                # Seek/read failed (e.g. past the end of the video): skip.
                continue
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(rgb))
    finally:
        # Always free the capture handle, even if a conversion above raises.
        cap.release()
    return frames
def format_message(content, files):
    """Build the multimodal content list for a single user message.

    The text is split on ``<image>`` placeholders. Each placeholder consumes
    the next uploaded file (opened as an image) so that it is interleaved with
    the surrounding text. Any files left over are appended afterwards: images
    as one entry each, videos as one image entry per sampled frame.

    Args:
        content: Message text, possibly containing ``<image>`` placeholders.
            May be empty or None.
        files: Iterable of uploaded files, each either a path string or an
            object exposing the path as ``.name``. May be empty or None.
            The caller's list is never modified.

    Returns:
        A list of ``{"type": "text", "text": ...}`` and
        ``{"type": "image", "image": ...}`` dictionaries.
    """
    # The upload widget accepts generic "image"/"video" files, so recognise
    # the common extensions rather than silently dropping them
    # (assumes the installed Pillow/OpenCV builds can decode these formats).
    image_suffixes = ('.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif')
    video_suffixes = ('.mp4', '.mov', '.avi', '.mkv', '.webm')

    # Work on a private copy: placeholders consume files with pop(), which
    # must not leak back into the caller's list. Also tolerates files=None.
    pending_files = list(files) if files else []
    message_content = []

    if content:
        parts = content.split('<image>')
        last_index = len(parts) - 1
        for i, part in enumerate(parts):
            stripped = part.strip()
            if stripped:
                message_content.append({"type": "text", "text": stripped})
            # Every gap between two parts corresponds to one placeholder.
            if i < last_index and pending_files:
                img = Image.open(pending_files.pop(0))
                message_content.append({"type": "image", "image": img})

    for file in pending_files:
        file_path = file if isinstance(file, str) else file.name
        suffix = Path(file_path).suffix.lower()
        if suffix in image_suffixes:
            img = Image.open(file_path)
            message_content.append({"type": "image", "image": img})
        elif suffix in video_suffixes:
            # A video is represented as a sequence of sampled still frames.
            for frame in extract_video_frames(file_path):
                message_content.append({"type": "image", "image": frame})
    return message_content
def format_conversation_history(chat_history):
    """Convert the chat history into chat-template style messages.

    Consecutive user entries are merged into a single user message whose
    content is a list of parts; each assistant entry becomes its own message
    with one text part. Entries with any other role are ignored.

    Args:
        chat_history: Iterable of dicts with "role" and "content" keys.

    Returns:
        A list of ``{"role": ..., "content": [parts...]}`` dictionaries.
    """
    conversation = []
    pending_parts = []  # user parts collected since the last assistant turn

    for entry in chat_history:
        speaker, payload = entry["role"], entry["content"]

        if speaker == "assistant":
            # Close the user turn that precedes this reply, if there is one.
            if pending_parts:
                conversation.append({"role": "user", "content": pending_parts})
                pending_parts = []
            reply_part = {"type": "text", "text": str(payload)}
            conversation.append({"role": "assistant", "content": [reply_part]})
            continue

        if speaker != "user":
            continue

        if isinstance(payload, list):
            # Already a list of parts: splice it in as-is.
            pending_parts.extend(payload)
        else:
            text = payload if isinstance(payload, str) else str(payload)
            pending_parts.append({"type": "text", "text": text})

    # A trailing user turn that has not been answered yet.
    if pending_parts:
        conversation.append({"role": "user", "content": pending_parts})
    return conversation
@spaces.GPU() # duration=120
def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
    """Stream the model's reply to the newest user message.

    Args:
        input_data: Either a dict with a "text" key and an optional "files"
            key, or any other value, which is converted to text with ``str``.
        chat_history: Previous turns as dicts with "role" and "content" keys.
        max_new_tokens: Upper bound on generated tokens (coerced to int).
        system_prompt: Optional system instruction; omitted when empty.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus sampling value forwarded to ``model.generate``.
        top_k: Top-k sampling value (coerced to int).
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.

    Yields:
        The reply accumulated so far; each yielded string contains all text
        streamed up to that point.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread is
            re-raised here once streaming has stopped.
    """
    if isinstance(input_data, dict) and "text" in input_data:
        text = input_data["text"]
        # "files" may be absent or None; copy it so helpers that consume
        # files never touch the caller's payload.
        files = list(input_data.get("files") or [])
    else:
        text = str(input_data)
        files = []

    new_message_content = format_message(text, files)
    new_message = {"role": "user", "content": new_message_content}
    system_message = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}] if system_prompt else []
    processed_history = format_conversation_history(chat_history)
    messages = system_message + processed_history

    # If the history already ends with an unanswered user turn, fold the new
    # content into it instead of emitting two consecutive user messages.
    if messages and messages[-1]["role"] == "user":
        messages[-1]["content"].extend(new_message["content"])
    else:
        messages.append(new_message)

    # Use the single Shako v4 model
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True
    ).to(model.device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        # UI sliders may deliver floats; these two must be integers.
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=repetition_penalty
    )

    generation_errors = []

    def _run_generation():
        # Runs in the worker thread. If generation fails, the streamer would
        # never receive its end-of-stream signal and the consumer loop below
        # would block forever, so record the error and close the stream.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            generation_errors.append(exc)
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()

    outputs = []
    for chunk in streamer:
        outputs.append(chunk)
        yield "".join(outputs)

    # The stream is finished, so the worker is done (or about to be).
    thread.join()
    if generation_errors:
        raise generation_errors[0]
# Extra controls rendered with the chat box. Their order mirrors the
# parameters of generate_response that follow (input_data, chat_history):
# max_new_tokens, system_prompt, temperature, top_p, top_k,
# repetition_penalty.
_generation_controls = [
    gr.Slider(
        label="Max new tokens",
        minimum=100,
        maximum=2000,
        step=1,
        value=512,
    ),
    gr.Textbox(
        label="System Prompt",
        value="You are a friendly chatbot. ",
        lines=4,
        placeholder="Change system prompt",
    ),
    gr.Slider(
        label="Temperature",
        minimum=0.1,
        maximum=2.0,
        step=0.1,
        value=0.7,
    ),
    gr.Slider(
        label="Top-p",
        minimum=0.05,
        maximum=1.0,
        step=0.05,
        value=0.9,
    ),
    gr.Slider(
        label="Top-k",
        minimum=1,
        maximum=100,
        step=1,
        value=50,
    ),
    gr.Slider(
        label="Repetition Penalty",
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        value=1.0,
    ),
]

# Markdown shown above the chat (user-facing text, kept verbatim).
_app_description = """
# شكو - Shako Iraqi AI
نموذج ذكاء عراقي متقدم يتحدث بالعراقي، يدعم الصور والفيديوهات والمحادثات الصوتية.
"""

# Input widget: free text plus any number of image/video uploads.
_query_box = gr.MultimodalTextbox(
    label="Query Input",
    file_types=["image", "video"],
    file_count="multiple",
    placeholder="Type your message or upload media",
)

# Chat UI wired to the streaming generate_response handler.
demo = gr.ChatInterface(
    fn=generate_response,
    additional_inputs=_generation_controls,
    examples=[
        [{"text": "Explain this image", "files": ["examples/image1.jpg"]}],
    ],
    cache_examples=False,
    type="messages",
    description=_app_description,
    fill_height=True,
    textbox=_query_box,
    stop_btn="Stop Generation",
    multimodal=True,
    theme=gr.themes.Soft(),
)

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()