import os
import uuid
from typing import Dict, Optional
from fastapi import FastAPI, HTTPException
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
import torch
from pydantic import BaseModel
import traceback
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationChain
from langchain.prompts import PromptTemplate
from starlette.responses import StreamingResponse
import asyncio
import json
from langchain_community.llms import HuggingFacePipeline
import uvicorn
from huggingface_hub import login

app = FastAPI()
# Get the Hugging Face API token from environment variables (best practice).
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if HUGGINGFACEHUB_API_TOKEN is None:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")

# --- Explicitly log in to Hugging Face Hub ---
try:
    login(token=HUGGINGFACEHUB_API_TOKEN)
    print("Successfully logged into Hugging Face Hub.")
except Exception as e:
    print(f"Failed to log into Hugging Face Hub: {e}")
# --- Initialize tokenizer and model globally (heavy to load, shared across sessions) ---
model_id = "mistralai/Mistral-7B-Instruct-v0.3"

# --- Quantization configuration for 4-bit loading, optimized for T4 ---
# This tells Hugging Face Transformers to load the model weights in 4-bit
# precision via the bitsandbytes library.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # Enable 4-bit quantization
    bnb_4bit_quant_type="nf4",             # "nf4" (NormalFloat 4-bit) is recommended for transformer weights
    # T4 GPUs (Turing architecture) have no native bfloat16 support, so use
    # float16 as the compute dtype; it is faster and helps avoid CPU offloading.
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,        # Double quantization for slightly better quality
)
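
# Rough sizing (estimate, not measured): ~7.2B parameters at 4 bits is about
# 7.2e9 * 0.5 bytes ≈ 3.6 GB of weights, plus quantization constants and the
# KV cache, so the quantized model should fit comfortably in a T4's 16 GB of VRAM.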
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HUGGINGFACEHUB_API_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",               # Let accelerate place layers, offloading to CPU only if necessary
    quantization_config=bnb_config,  # Load weights in 4-bit as configured above
    trust_remote_code=True,
    token=HUGGINGFACEHUB_API_TOKEN,
)
# Global dictionary of active conversation chains, keyed by session_id.
# IMPORTANT: this in-memory store is lost whenever the server restarts. For true
# persistence you would use an external database (e.g., Redis or Firestore).
active_conversations: Dict[str, ConversationChain] = {}
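
# A persistent variant could instead store serialized chat turns in an external
# store. The sketch below is illustrative only: it assumes a reachable Redis
# server, the `redis` package, and `question`/`answer` variables at call time.
#
#   import redis
#   r = redis.Redis(host="localhost", port=6379, decode_responses=True)
#   r.rpush(f"history:{session_id}", json.dumps({"human": question, "ai": answer}))
#   last_turns = [json.loads(m) for m in r.lrange(f"history:{session_id}", -5, -1)]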
# --- Prompt template (ChatML-style) ---
template = """<|im_start|>system
You are a concise and direct AI assistant named Siddhi.
You strictly avoid asking any follow-up questions.
You do not generate any additional conversational turns (e.g., "Human: ...").
If asked for your name, you respond with "I am Siddhi."
If you do not know the answer to a question, you truthfully say that you do not know.
<|im_end|>
<|im_start|>user
{history}
{input}<|im_end|>
<|im_start|>assistant
"""
PROMPT = PromptTemplate(input_variables=["history", "input"], template=template)
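
# For reference, a rendered prompt on a later turn looks roughly like this
# (ConversationBufferWindowMemory fills {history} with "Human:"/"AI:"-prefixed turns):
#
#   <|im_start|>system
#   You are a concise and direct AI assistant named Siddhi.
#   ...
#   <|im_end|>
#   <|im_start|>user
#   Human: What is your name?
#   AI: I am Siddhi.
#   Tell me a fun fact about space.<|im_end|>
#   <|im_start|>assistant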
class QuestionRequest(BaseModel):
    question: str
    session_id: Optional[str] = None  # Optional session ID for continuing an existing conversation


class ChatResponse(BaseModel):
    response: str
    session_id: str  # Returned so the client can track the conversation
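
# Example request bodies (illustrative): the first call omits "session_id";
# follow-up calls echo back the "session_id" the server streams in its first line.
#   {"question": "What is your name?"}
#   {"question": "And what can you do?", "session_id": "d4f0c9e2-..."}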
@app.post("/generate")  # NOTE: route path assumed; the original snippet did not show the decorator
async def generate_text(request: QuestionRequest):
    """
    Handles text generation requests, maintaining conversation history per session.
    """
    session_id = request.session_id

    # If no session_id is provided, generate a new one; this marks the start of a
    # new conversation.
    if session_id is None:
        session_id = str(uuid.uuid4())
        print(f"Starting new conversation with session_id: {session_id}")

    # Retrieve or create the ConversationChain for this session_id.
    if session_id not in active_conversations:
        print(f"Creating new ConversationChain for session_id: {session_id}")
        # Wrap the globally loaded model and tokenizer in a LangChain HuggingFacePipeline.
        llm = HuggingFacePipeline(pipeline=pipeline(
            "text-generation",
            model=model,          # globally loaded model
            tokenizer=tokenizer,  # globally loaded tokenizer
            max_new_tokens=512,
            return_full_text=True,
            temperature=0.2,
            do_sample=True,
        ))
        # Per-session memory: remembers the last 5 human-AI interaction pairs.
        memory = ConversationBufferWindowMemory(k=5)
        conversation = ConversationChain(llm=llm, memory=memory, prompt=PROMPT, verbose=True)
        active_conversations[session_id] = conversation
    else:
        print(f"Continuing conversation for session_id: {session_id}")
        conversation = active_conversations[session_id]
    async def generate_stream():
        """
        Asynchronous generator that streams the response as newline-delimited JSON.
        Each yielded line is one JSON object describing an event in the stream.
        """
        # Becomes True once the start of the assistant's actual reply has been found.
        started_streaming_ai_response = False
        try:
            # First, send the session_id so the client can track the conversation
            # from the very first line of the stream.
            yield json.dumps({"type": "session_info", "session_id": session_id}) + "\n"

            response_stream = conversation.stream({"input": request.question})

            stop_sequences_to_check = ["Human:", "AI:", "\nHuman:", "\nAI:", "<|im_end|>"]
            assistant_start_marker = "<|im_start|>assistant\n"

            for chunk in response_stream:
                # Chain chunks are dicts with a "response" key; fall back to str().
                if 'response' in chunk:
                    full_text_chunk = chunk['response']
                else:
                    full_text_chunk = str(chunk)

                # Because return_full_text=True, the output contains the prompt;
                # skip everything up to (and including) the assistant marker.
                if not started_streaming_ai_response:
                    if assistant_start_marker in full_text_chunk:
                        token_content = full_text_chunk.split(assistant_start_marker, 1)[1]
                        started_streaming_ai_response = True
                    else:
                        token_content = ""
                else:
                    token_content = full_text_chunk

                # If a stop sequence appears, truncate at it, flush what remains,
                # signal completion, and stop streaming.
                for stop_seq in stop_sequences_to_check:
                    if stop_seq in token_content:
                        token_content = token_content.split(stop_seq, 1)[0]
                        if token_content:
                            yield json.dumps({"type": "token", "content": token_content}) + "\n"
                            await asyncio.sleep(0.01)
                        yield json.dumps({"type": "end", "status": "completed", "session_id": session_id}) + "\n"
                        return

                if token_content:
                    yield json.dumps({"type": "token", "content": token_content}) + "\n"
                    await asyncio.sleep(0.01)

            yield json.dumps({"type": "end", "status": "completed", "session_id": session_id}) + "\n"
        except Exception as e:
            print(f"Error during streaming generation for session {session_id}:")
            traceback.print_exc()
            yield json.dumps({"type": "error", "message": str(e), "session_id": session_id}) + "\n"

    # Return a StreamingResponse; each line of the body is one JSON object.
    return StreamingResponse(generate_stream(), media_type="application/json")
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
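
# --- Example client (sketch) ---
# One way to consume the newline-delimited JSON stream from another process. This
# is illustrative only: it assumes the handler above is registered at POST /generate
# (see the note on the decorator) and that the `requests` package is installed.
#
#   import json, requests
#
#   with requests.post(
#       "http://localhost:7860/generate",
#       json={"question": "What is your name?"},
#       stream=True,
#   ) as resp:
#       for line in resp.iter_lines():
#           if not line:
#               continue
#           event = json.loads(line)
#           if event["type"] == "session_info":
#               session_id = event["session_id"]  # reuse in the next request body
#           elif event["type"] == "token":
#               print(event["content"], end="", flush=True)
#           elif event["type"] in ("end", "error"):
#               break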