"""End-to-end RAG pipeline for UB international-student questions.

Retrieves relevant chunks from a VectorStore, formats them into a prompt,
and generates an answer with a causal language model.
"""
from typing import List, Dict, Any, Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from buffalo_rag.vector_store.db import VectorStore
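# Note: hybrid_search is assumed to blend dense-vector and keyword
# retrieval; see buffalo_rag/vector_store/db.py for the actual behavior.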

class BuffaloRAG:
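    """Retrieval-augmented generation over UB international-student content.

    retrieve() pulls chunks from the vector store, format_context() turns
    them into a prompt block, and generate_response() drafts the answer;
    answer() runs the full pipeline and returns the response with sources.
    """
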
    def __init__(self, 
                 model_name: str = "Qwen/Qwen1.5-1.8B-Chat",
                 vector_store: Optional[VectorStore] = None):
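        """Load the chat model; fall back to distilgpt2 if loading fails."""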
        self.vector_store = vector_store or VectorStore()
        
        try:
            # Load model and tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name, 
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
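            # Note: device_map="auto" requires the `accelerate` package;
            # float16 weights roughly halve memory use relative to float32.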
            
            # More conservative generation parameters for stability
            self.pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_new_tokens=256,  # Shorter outputs for stability
                do_sample=False,     # Use greedy decoding instead of sampling
                pad_token_id=self.tokenizer.eos_token_id
            )
        except Exception as e:
            print(f"Error loading main model: {str(e)}")
            print("Falling back to smaller model...")
            # Fall back to a smaller, more stable model so the service
            # still responds, albeit with lower answer quality.
            self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
            self.model = AutoModelForCausalLM.from_pretrained("distilgpt2")
            self.pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_new_tokens=256,
                # GPT-2 models have no pad token; reuse EOS to avoid warnings
                pad_token_id=self.tokenizer.eos_token_id
            )
    
    def retrieve(self, 
                query: str, 
                k: int = 5,
                filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Retrieve relevant chunks for a query."""
        return self.vector_store.hybrid_search(query, k=k, filter_categories=filter_categories)
    
    def format_context(self, results: List[Dict[str, Any]]) -> str:
        """Format retrieved results into a numbered context block."""
        context = ""
        
        for i, result in enumerate(results):
            chunk = result['chunk']
            snippet = chunk['content'][:500]
            if len(chunk['content']) > 500:
                snippet += "..."  # mark truncation only when content was cut
            context += f"Source {i+1}: {chunk['title']}\n"
            context += f"URL: {chunk['url']}\n"
            context += f"Content: {snippet}\n\n"
        
        return context
    
    def generate_response(self, query: str, context: str) -> str:
        """Generate response using the language model with error handling."""
        prompt = f"""You are a friendly and professional counselor for international students at the University at Buffalo. Respond to the student's query in a supportive, detailed, and well-structured manner.

For your responses:
1. Address the student respectfully and empathetically
2. Provide clear, accurate information with specific details and steps when applicable
3. Organize your answer with appropriate headings, bullet points, or numbered lists when helpful
4. If the student's question is unclear or lacks essential details, ask 1-2 specific clarifying questions to better understand their situation
5. Include relevant deadlines, contacts, or resources when appropriate
6. Conclude with a brief encouraging statement
7. Only answer questions related to international students at UB; if a question is not related to international students at UB, just say "I'm sorry, I don't have information about that."
8. Do not entertain any questions that are not related to students at UB.

Question: {query}

Relevant Information:
{context}

Answer:"""
        
        try:
            # The pipeline output contains the prompt followed by the
            # completion; tokenization round-trips can perturb whitespace,
            # so only strip the prompt when it is an exact prefix.
            response = self.pipe(prompt)[0]['generated_text']
            if response.startswith(prompt):
                response = response[len(prompt):]
            return response.strip()
        except Exception as e:
            print(f"Error during generation: {str(e)}")
            # Fallback response
            return "I'm sorry, I encountered an issue generating a response. Please try asking your question in a different way or contact UB International Student Services directly for assistance."
    
    def answer(self, 
              query: str, 
              k: int = 5,
              filter_categories: Optional[List[str]] = None) -> Dict[str, Any]:
        """End-to-end RAG pipeline."""
        # Retrieve relevant chunks
        results = self.retrieve(query, k=k, filter_categories=filter_categories)
        
        # Format context
        context = self.format_context(results)
        
        # Generate response
        response = self.generate_response(query, context)
        
        # Return response and sources
        return {
            'query': query,
            'response': response,
            'sources': [
                {
                    'title': result['chunk']['title'],
                    'url': result['chunk']['url'],
                    'score': result.get('rerank_score', result['score'])
                }
                for result in results
            ]
        }

# Example usage
if __name__ == "__main__":
    rag = BuffaloRAG(model_name="1bitLLM/bitnet_b1_58-large")
    response = rag.answer("How do I apply for OPT?")
    
    print(f"Query: {response['query']}")
    print(f"Response: {response['response']}")
    print("\nSources:")
    for source in response['sources']:
        print(f"- {source['title']} (Score: {source['score']:.4f})")